Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 56 additions & 87 deletions xrspatial/focal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
class cupy(object):
ndarray = False

from xrspatial.convolution import (_available_memory_bytes, _convolve_2d_cupy, _convolve_2d_numpy,
_promote_float, convolve_2d, custom_kernel)
from xrspatial.convolution import (_available_memory_bytes, _promote_float, convolve_2d,
custom_kernel)
from xrspatial.dataset_support import supports_dataset
from xrspatial.utils import (ArrayTypeFunctionMapping, _boundary_to_dask, _pad_array,
_validate_boundary, _validate_raster, _validate_scalar, cuda_args,
Expand Down Expand Up @@ -1346,12 +1346,12 @@ def _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
# large numbers and loses precision in float32 for big rasters / weights.
weight_sum = weight_sum.astype(np.float64)
sq_weight_sum = sq_weight_sum.astype(np.float64)
numerator = weighted_sum.astype(np.float64) - float(global_mean) * weight_sum
numerator = weighted_sum.astype(np.float64) - global_mean * weight_sum
variance_term = (n * sq_weight_sum - weight_sum * weight_sum) / (n - 1)
# Guard against tiny negatives from float rounding and the degenerate
# single-cell neighborhood (variance_term == 0) before the sqrt.
variance_term = np.where(variance_term > 0, variance_term, np.nan)
denominator = float(global_std) * np.sqrt(variance_term)
denominator = global_std * np.sqrt(variance_term)
z = numerator / denominator
return np.where(np.isfinite(z), z, 0.0).astype(np.float32)

Expand Down Expand Up @@ -1395,99 +1395,68 @@ def _hotspots_numpy(raster, kernel, boundary='nan'):


def _hotspots_dask_numpy(raster, kernel, boundary='nan'):
data = raster.data
if not np.issubdtype(data.dtype, np.floating):
data = data.astype(np.float32)

# Pass 1: eagerly compute global Gi* terms (three scalars).
# This reads all chunks once, produces a few bytes, then frees all
# intermediate state -- no barrier that would force materialization
# of the full convolution output.
valid = ~da.isnan(data)
global_mean, global_std, n = da.compute(
da.nanmean(data), da.nanstd(data), valid.sum())
global_mean = np.float32(global_mean)
global_std = np.float32(global_std)
n = int(n)
_gistar_global_stats(global_mean, global_std, n)
# Match the numpy path: compute in float32 so the convolution and the
# float32 map_overlap meta agree regardless of input dtype.
data = raster.data.astype(np.float32)

kernel = kernel.astype(np.float32)
pad_h = kernel.shape[0] // 2
pad_w = kernel.shape[1] // 2
# Global Gi* terms stay lazy: 0-d dask arrays that broadcast into the
# z-score below. Nothing is computed during graph construction.
valid = ~da.isnan(data)
global_mean = da.nanmean(data)
global_std = da.nanstd(data)
n = valid.sum()

# Per-cell Gi* convolution terms via convolve_2d's lazy dask path,
# mirroring _gistar_convolutions_numpy on the single-array backend so
# the dask result matches numpy. NaN cells are excluded via the
# validity mask.
valid_f = valid.astype(np.float32)
filled = da.where(valid, data, np.float32(0.0))
weighted_sum = convolve_2d(filled, kernel, boundary)
weight_sum = convolve_2d(valid_f, kernel, boundary)
sq_weight_sum = convolve_2d(valid_f, kernel * kernel, boundary)

# Pass 2: fuse the three convolutions + Gi* z-score + classification
# into one map_overlap call. Each chunk reads source + halo, produces
# int8 output, and frees all intermediates immediately.
_func = partial(
_hotspots_chunk,
kernel=kernel,
global_mean=global_mean,
global_std=global_std,
n=n,
)
out = data.map_overlap(
_func,
depth=(pad_h, pad_w),
boundary=_boundary_to_dask(boundary),
meta=np.array((), dtype=np.int8),
)
# Gi* z-score via broadcast of the lazy 0-d global terms, then classify
# per block.
z_array = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
global_mean, global_std, n)
out = z_array.map_blocks(_calc_hotspots_numpy,
meta=np.array((), dtype=np.int8))
return out


def _hotspots_chunk(chunk, kernel, global_mean, global_std, n):
"""Fused per-chunk: convolve Gi* terms -> z-score -> classify."""
valid = (~np.isnan(chunk)).astype(np.float32)
filled = np.where(valid > 0, chunk, np.float32(0.0))
weighted_sum = _convolve_2d_numpy(filled, kernel)
weight_sum = _convolve_2d_numpy(valid, kernel)
sq_weight_sum = _convolve_2d_numpy(valid, kernel * kernel)
z = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
global_mean, global_std, n)
return _calc_hotspots_numpy(z)
def _calc_hotspots_cupy(z):
"""Per-chunk GPU classification of a z-score array."""
out = cupy.zeros_like(z, dtype=cupy.int8)
griddim, blockdim = cuda_args(z.shape)
_run_gpu_hotspots[griddim, blockdim](z, out)
return out


def _hotspots_dask_cupy(raster, kernel, boundary='nan'):
data = raster.data
if not cupy.issubdtype(data.dtype, cupy.floating):
data = data.astype(cupy.float32)

# Pass 1: global Gi* terms (three scalars, eager)
valid_global = ~da.isnan(data)
global_mean, global_std, n = da.compute(
da.nanmean(data), da.nanstd(data), valid_global.sum())
global_mean = np.float32(float(global_mean))
global_std = np.float32(float(global_std))
n = int(n)
_gistar_global_stats(global_mean, global_std, n)
# Match the numpy path: compute in float32 so the convolution and the
# float32 map_overlap meta agree regardless of input dtype.
data = raster.data.astype(cupy.float32)

kernel = kernel.astype(np.float32)
sq_kernel = kernel * kernel
pad_h = kernel.shape[0] // 2
pad_w = kernel.shape[1] // 2
# Global Gi* terms stay lazy: 0-d dask arrays that broadcast into the
# z-score below. Nothing is computed during graph construction.
valid = ~da.isnan(data)
global_mean = da.nanmean(data)
global_std = da.nanstd(data)
n = valid.sum()

# Per-cell Gi* convolution terms via convolve_2d's lazy dask+cupy path;
# each chunk stays on the device (no host round trip).
valid_f = valid.astype(cupy.float32)
filled = da.where(valid, data, cupy.float32(0.0))
weighted_sum = convolve_2d(filled, kernel, boundary)
weight_sum = convolve_2d(valid_f, kernel, boundary)
sq_weight_sum = convolve_2d(valid_f, kernel * kernel, boundary)

# Pass 2: fuse the three convolutions + Gi* z-score + classification,
# all on the GPU. Reuse the _run_gpu_hotspots kernel (same as the
# single-GPU path) so each chunk stays on the device -- no host round
# trip per chunk.
def _chunk_fn(chunk):
valid = (~cupy.isnan(chunk)).astype(cupy.float32)
filled = cupy.where(valid > 0, chunk, cupy.float32(0.0))
weighted_sum = _convolve_2d_cupy(filled, kernel)
weight_sum = _convolve_2d_cupy(valid, kernel)
sq_weight_sum = _convolve_2d_cupy(valid, sq_kernel)
z = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
global_mean, global_std, n)
out = cupy.zeros_like(z, dtype=cupy.int8)
griddim, blockdim = cuda_args(z.shape)
_run_gpu_hotspots[griddim, blockdim](z, out)
return out

out = data.map_overlap(
_chunk_fn,
depth=(pad_h, pad_w),
boundary=_boundary_to_dask(boundary, is_cupy=True),
meta=cupy.array((), dtype=cupy.int8),
)
z_array = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
global_mean, global_std, n)
out = z_array.map_blocks(_calc_hotspots_cupy,
meta=cupy.array((), dtype=cupy.int8))
return out


Expand Down
30 changes: 30 additions & 0 deletions xrspatial/tests/test_dask_laziness.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
try:
import dask
import dask.array as da
import dask.callbacks # noqa: F401 (registers dask.callbacks.Callback)
except ImportError:
da = None

Expand All @@ -32,6 +33,20 @@ def _is_lazy(result):
return False


class _TaskCounter(dask.callbacks.Callback):
"""Count dask tasks executed while the context is active.

Used to assert that a function builds its graph lazily (zero tasks run
during the call) rather than eagerly triggering computation.
"""

def __init__(self):
self.count = 0

def _posttask(self, key, result, dsk, state, worker_id):
self.count += 1


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -117,6 +132,21 @@ def test_hotspots(self, elev):
kernel = np.ones((3, 3), dtype=np.float32) / 9
assert _is_lazy(hotspots(elev, kernel))

def test_hotspots_no_eager_compute(self, elev):
# hotspots() must build the graph without running any dask tasks.
# Previously it eagerly computed global mean/std, executing ~12
# tasks on the call itself (issue #2772).
from xrspatial.focal import hotspots
kernel = np.ones((3, 3), dtype=np.float32) / 9
with _TaskCounter() as counter:
result = hotspots(elev, kernel)
assert counter.count == 0, (
f"hotspots() ran {counter.count} dask tasks on call; expected 0 "
f"(should stay lazy)"
)
# The graph must still compute to a real result.
assert result.compute().shape == elev.shape


# ---------------------------------------------------------------------------
# Classification (partially lazy -- returns dask after computing small stats)
Expand Down
Loading