Skip to content

Commit ea6a465

Browse files
authored
NumPy zonal stats: return a data array of calculated stats (#685)
* zonal_stats returns a DataArray * flake8
1 parent b02d5f8 commit ea6a465

File tree

2 files changed

+195
-42
lines changed

2 files changed

+195
-42
lines changed

xrspatial/tests/test_zonal.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from xrspatial import zonal_stats as stats
1212
from xrspatial.zonal import regions
1313

14-
from .general_checks import create_test_raster, has_cuda_and_cupy
14+
from .general_checks import create_test_raster, general_output_checks, has_cuda_and_cupy
1515

1616

1717
@pytest.fixture
@@ -60,6 +60,40 @@ def result_default_stats():
6060
return expected_result
6161

6262

63+
@pytest.fixture
64+
def result_default_stats_dataarray():
65+
expected_result = np.array(
66+
[[[0., 0., 1., 1., 2., 2., 2.4, 2.4],
67+
[0., 0., 1., 1., 2., 2., 2.4, 2.4],
68+
[0., 0., 1., 1., 2., np.nan, 2.4, 2.4]],
69+
70+
[[0., 0., 1., 1., 2., 2., 3., 3.],
71+
[0., 0., 1., 1., 2., 2., 3., 3.],
72+
[0., 0., 1., 1., 2., np.nan, 3., 3.]],
73+
74+
[[0., 0., 1., 1., 2., 2., 0., 0.],
75+
[0., 0., 1., 1., 2., 2., 0., 0.],
76+
[0., 0., 1., 1., 2., np.nan, 0., 0.]],
77+
78+
[[0., 0., 6., 6., 8., 8., 12., 12.],
79+
[0., 0., 6., 6., 8., 8., 12., 12.],
80+
[0., 0., 6., 6., 8., np.nan, 12., 12.]],
81+
82+
[[0., 0., 0., 0., 0., 0., 1.2, 1.2],
83+
[0., 0., 0., 0., 0., 0., 1.2, 1.2],
84+
[0., 0., 0., 0., 0., np.nan, 1.2, 1.2]],
85+
86+
[[0., 0., 0., 0., 0., 0., 1.44, 1.44],
87+
[0., 0., 0., 0., 0., 0., 1.44, 1.44],
88+
[0., 0., 0., 0., 0., np.nan, 1.44, 1.44]],
89+
90+
[[5., 5., 6., 6., 4., 4., 5., 5.],
91+
[5., 5., 6., 6., 4., 4., 5., 5.],
92+
[5., 5., 6., 6., 4., np.nan, 5., 5.]]]
93+
)
94+
return expected_result
95+
96+
6397
@pytest.fixture
6498
def result_zone_ids_stats():
6599
zone_ids = [0, 3]
@@ -76,6 +110,41 @@ def result_zone_ids_stats():
76110
return zone_ids, expected_result
77111

78112

113+
@pytest.fixture
114+
def result_zone_ids_stats_dataarray():
115+
zone_ids = [0, 3]
116+
expected_result = np.array(
117+
[[[0., 0., np.nan, np.nan, np.nan, np.nan, 2.4, 2.4],
118+
[0., 0., np.nan, np.nan, np.nan, np.nan, 2.4, 2.4],
119+
[0., 0., np.nan, np.nan, np.nan, np.nan, 2.4, 2.4]],
120+
121+
[[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.],
122+
[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.],
123+
[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.]],
124+
125+
[[0., 0., np.nan, np.nan, np.nan, np.nan, 0., 0.],
126+
[0., 0., np.nan, np.nan, np.nan, np.nan, 0., 0.],
127+
[0., 0., np.nan, np.nan, np.nan, np.nan, 0., 0.]],
128+
129+
[[0., 0., np.nan, np.nan, np.nan, np.nan, 12., 12.],
130+
[0., 0., np.nan, np.nan, np.nan, np.nan, 12., 12.],
131+
[0., 0., np.nan, np.nan, np.nan, np.nan, 12., 12.]],
132+
133+
[[0., 0., np.nan, np.nan, np.nan, np.nan, 1.2, 1.2],
134+
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.2, 1.2],
135+
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.2, 1.2]],
136+
137+
[[0., 0., np.nan, np.nan, np.nan, np.nan, 1.44, 1.44],
138+
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.44, 1.44],
139+
[0., 0., np.nan, np.nan, np.nan, np.nan, 1.44, 1.44]],
140+
141+
[[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.],
142+
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.],
143+
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.]]])
144+
145+
return zone_ids, expected_result
146+
147+
79148
def _double_sum(values):
80149
return values.sum() * 2
81150

@@ -96,6 +165,22 @@ def result_custom_stats():
96165
return nodata_values, zone_ids, expected_result
97166

98167

168+
@pytest.fixture
169+
def result_custom_stats_dataarray():
170+
zone_ids = [1, 2]
171+
nodata_values = 0
172+
expected_result = np.array(
173+
[[[np.nan, np.nan, 12., 12., 16., 16., np.nan, np.nan],
174+
[np.nan, np.nan, 12., 12., 16., 16., np.nan, np.nan],
175+
[np.nan, np.nan, 12., 12., 16., np.nan, np.nan, np.nan]],
176+
177+
[[np.nan, np.nan, 0., 0., 0., 0., np.nan, np.nan],
178+
[np.nan, np.nan, 0., 0., 0., 0., np.nan, np.nan],
179+
[np.nan, np.nan, 0., 0., 0., np.nan, np.nan, np.nan]]]
180+
)
181+
return nodata_values, zone_ids, expected_result
182+
183+
99184
@pytest.fixture
100185
def result_count_crosstab_2d():
101186
zone_ids = [1, 2, 3]
@@ -174,6 +259,22 @@ def test_default_stats(backend, data_zones, data_values_2d, result_default_stats
174259
check_results(backend, df_result, result_default_stats)
175260

176261

262+
@pytest.mark.parametrize("backend", ['numpy'])
263+
def test_default_stats_dataarray(
264+
backend, data_zones, data_values_2d, result_default_stats_dataarray
265+
):
266+
dataarray_result = stats(
267+
zones=data_zones, values=data_values_2d, return_type='xarray.DataArray'
268+
)
269+
general_output_checks(
270+
data_values_2d,
271+
dataarray_result,
272+
result_default_stats_dataarray,
273+
verify_dtype=False,
274+
verify_attrs=False,
275+
)
276+
277+
177278
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy'])
178279
def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_stats):
179280
if backend == 'cupy' and not has_cuda_and_cupy():
@@ -184,6 +285,19 @@ def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_sta
184285
check_results(backend, df_result, expected_result)
185286

186287

288+
@pytest.mark.parametrize("backend", ['numpy'])
289+
def test_zone_ids_stats_dataarray(
290+
backend, data_zones, data_values_2d, result_zone_ids_stats_dataarray
291+
):
292+
zone_ids, expected_result = result_zone_ids_stats_dataarray
293+
dataarray_result = stats(
294+
zones=data_zones, values=data_values_2d, zone_ids=zone_ids, return_type='xarray.DataArray'
295+
)
296+
general_output_checks(
297+
data_values_2d, dataarray_result, expected_result, verify_dtype=False, verify_attrs=False
298+
)
299+
300+
187301
@pytest.mark.parametrize("backend", ['numpy', 'cupy'])
188302
def test_custom_stats(backend, data_zones, data_values_2d, result_custom_stats):
189303
# ---- custom stats (NumPy and CuPy only) ----
@@ -203,6 +317,23 @@ def test_custom_stats(backend, data_zones, data_values_2d, result_custom_stats):
203317
check_results(backend, df_result, expected_result)
204318

205319

320+
@pytest.mark.parametrize("backend", ['numpy'])
321+
def test_custom_stats_dataarray(backend, data_zones, data_values_2d, result_custom_stats_dataarray):
322+
# ---- custom stats returns an xr.DataArray (NumPy only) ----
323+
custom_stats = {
324+
'double_sum': _double_sum,
325+
'range': _range,
326+
}
327+
nodata_values, zone_ids, expected_result = result_custom_stats_dataarray
328+
dataarray_result = stats(
329+
zones=data_zones, values=data_values_2d, stats_funcs=custom_stats,
330+
zone_ids=zone_ids, nodata_values=nodata_values, return_type='xarray.DataArray'
331+
)
332+
general_output_checks(
333+
data_values_2d, dataarray_result, expected_result, verify_dtype=False, verify_attrs=False
334+
)
335+
336+
206337
@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
207338
def test_count_crosstab_2d(backend, data_zones, data_values_2d, result_count_crosstab_2d):
208339
zone_ids, cat_ids, expected_result = result_count_crosstab_2d

xrspatial/zonal.py

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _sort_and_stride(zones, values, unique_zones):
105105
sorted_zones = sorted_zones[np.isfinite(sorted_zones)]
106106
zone_breaks = _strides(sorted_zones, unique_zones)
107107

108-
return values_by_zones, zone_breaks
108+
return sorted_indices, values_by_zones, zone_breaks
109109

110110

111111
def _calc_stats(
@@ -123,8 +123,7 @@ def _calc_stats(
123123
if unique_zones[i] in zone_ids:
124124
zone_values = values_by_zones[start:end]
125125
# filter out non-finite and nodata_values
126-
zone_values = zone_values[
127-
np.isfinite(zone_values) & (zone_values != nodata_values)]
126+
zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
128127
if len(zone_values) > 0:
129128
results[i] = func(zone_values)
130129
start = end
@@ -141,13 +140,8 @@ def _single_stats_func(
141140
nodata_values: Union[int, float] = None,
142141
) -> pd.DataFrame:
143142

144-
values_by_zones, zone_breaks = _sort_and_stride(
145-
zones_block, values_block, unique_zones
146-
)
147-
results = _calc_stats(
148-
values_by_zones, zone_breaks,
149-
unique_zones, zone_ids, func, nodata_values
150-
)
143+
_, values_by_zones, zone_breaks = _sort_and_stride(zones_block, values_block, unique_zones)
144+
results = _calc_stats(values_by_zones, zone_breaks, unique_zones, zone_ids, func, nodata_values)
151145
return results
152146

153147

@@ -224,19 +218,15 @@ def _stats_dask_numpy(
224218
stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
225219
if 'std' in stats_funcs:
226220
stats_dict['std'] = _dask_std(
227-
stats_dict['sum_squares'], stats_dict['sum'] ** 2,
228-
stats_dict['count']
221+
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
229222
)
230223
if 'var' in stats_funcs:
231224
stats_dict['var'] = _dask_var(
232-
stats_dict['sum_squares'], stats_dict['sum'] ** 2,
233-
stats_dict['count']
225+
stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
234226
)
235227

236228
# generate dask dataframe
237-
stats_df = dd.concat(
238-
[dd.from_dask_array(s) for s in stats_dict.values()], axis=1
239-
)
229+
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1)
240230
# name columns
241231
stats_df.columns = stats_dict.keys()
242232
# select columns
@@ -259,7 +249,8 @@ def _stats_numpy(
259249
zone_ids: List[Union[int, float]],
260250
stats_funcs: Dict,
261251
nodata_values: Union[int, float],
262-
) -> pd.DataFrame:
252+
return_type: str,
253+
) -> Union[pd.DataFrame, np.ndarray]:
263254

264255
# find ids for all zones
265256
unique_zones = np.unique(zones[np.isfinite(zones)])
@@ -271,23 +262,40 @@ def _stats_numpy(
271262
# remove zones that do not exist in `zones` raster
272263
zone_ids = [z for z in zone_ids if z in unique_zones]
273264

274-
selected_indexes = [i for i, z in enumerate(unique_zones) if z in zone_ids]
275-
values_by_zones, zone_breaks = _sort_and_stride(
276-
zones, values, unique_zones
277-
)
278-
279-
stats_dict = {}
280-
stats_dict["zone"] = zone_ids
281-
for stats in stats_funcs:
282-
func = stats_funcs.get(stats)
283-
stats_dict[stats] = _calc_stats(
284-
values_by_zones, zone_breaks,
285-
unique_zones, zone_ids, func, nodata_values
286-
)
287-
stats_dict[stats] = stats_dict[stats][selected_indexes]
265+
sorted_indices, values_by_zones, zone_breaks = _sort_and_stride(zones, values, unique_zones)
266+
if return_type == 'pandas.DataFrame':
267+
stats_dict = {}
268+
stats_dict["zone"] = zone_ids
269+
selected_indexes = [i for i, z in enumerate(unique_zones) if z in zone_ids]
270+
for stats in stats_funcs:
271+
func = stats_funcs.get(stats)
272+
stats_dict[stats] = _calc_stats(
273+
values_by_zones, zone_breaks,
274+
unique_zones, zone_ids, func, nodata_values
275+
)
276+
stats_dict[stats] = stats_dict[stats][selected_indexes]
277+
result = pd.DataFrame(stats_dict)
288278

289-
stats_df = pd.DataFrame(stats_dict)
290-
return stats_df
279+
else:
280+
result = np.full((len(stats_funcs), values.size), np.nan)
281+
zone_ids_map = {z: i for i, z in enumerate(unique_zones) if z in zone_ids}
282+
stats_id = 0
283+
for stats in stats_funcs:
284+
func = stats_funcs.get(stats)
285+
stats_results = _calc_stats(
286+
values_by_zones, zone_breaks,
287+
unique_zones, zone_ids, func, nodata_values
288+
)
289+
for zone in zone_ids:
290+
iz = zone_ids_map[zone] # position of zone in unique_zones
291+
if iz == 0:
292+
zs = sorted_indices[: zone_breaks[iz]]
293+
else:
294+
zs = sorted_indices[zone_breaks[iz-1]: zone_breaks[iz]]
295+
result[stats_id][zs] = stats_results[iz]
296+
stats_id += 1
297+
result = result.reshape(len(stats_funcs), *values.shape)
298+
return result
291299

292300

293301
def _stats_cupy(
@@ -391,7 +399,8 @@ def stats(
391399
"count",
392400
],
393401
nodata_values: Union[int, float] = None,
394-
) -> Union[pd.DataFrame, dd.DataFrame]:
402+
return_type: str = 'pandas.DataFrame',
403+
) -> Union[pd.DataFrame, dd.DataFrame, xr.DataArray]:
395404
"""
396405
Calculate summary statistics for each zone defined by a `zones`
397406
dataset, based on `values` aggregate.
@@ -438,6 +447,11 @@ def stats(
438447
Cells with `nodata_values` do not belong to any zone,
439448
and thus excluded from calculation.
440449
450+
return_type: str, default='pandas.DataFrame'
451+
Format of returned data. If `zones` and `values` numpy backed xarray DataArray,
452+
allowed values are 'pandas.DataFrame', and 'xarray.DataArray'.
453+
Otherwise, only 'pandas.DataFrame' is supported.
454+
441455
Returns
442456
-------
443457
stats_df : Union[pandas.DataFrame, dask.dataframe.DataFrame]
@@ -568,17 +582,25 @@ def stats(
568582
stats_funcs_dict = stats_funcs.copy()
569583

570584
mapper = ArrayTypeFunctionMapping(
571-
numpy_func=_stats_numpy,
585+
numpy_func=lambda *args: _stats_numpy(*args, return_type=return_type),
572586
dask_func=_stats_dask_numpy,
573587
cupy_func=_stats_cupy,
574588
dask_cupy_func=lambda *args: not_implemented_func(
575589
*args, messages='stats() does not support dask with cupy backed DataArray' # noqa
576590
),
577591
)
578-
stats_df = mapper(values)(
579-
zones.data, values.data, zone_ids, stats_funcs_dict, nodata_values
592+
result = mapper(values)(
593+
zones.data, values.data, zone_ids, stats_funcs_dict, nodata_values,
580594
)
581-
return stats_df
595+
596+
if return_type == 'xarray.DataArray':
597+
return xr.DataArray(
598+
result,
599+
coords={'stats': list(stats_funcs_dict.keys()), **values.coords},
600+
dims=('stats', *values.dims),
601+
attrs=values.attrs
602+
)
603+
return result
582604

583605

584606
def _find_cats(values, cat_ids, nodata_values):
@@ -680,7 +702,7 @@ def _crosstab_numpy(
680702
for cat in cat_ids:
681703
crosstab_dict[cat] = []
682704

683-
values_by_zones, zone_breaks = _sort_and_stride(
705+
_, values_by_zones, zone_breaks = _sort_and_stride(
684706
zones, values, unique_zones
685707
)
686708

@@ -731,7 +753,7 @@ def _single_chunk_crosstab(
731753
for cat in cat_ids:
732754
results[cat] = []
733755

734-
values_by_zones, zone_breaks = _sort_and_stride(
756+
_, values_by_zones, zone_breaks = _sort_and_stride(
735757
zones_block, values_block, unique_zones
736758
)
737759

0 commit comments

Comments (0)