1111from xrspatial import zonal_stats as stats
1212from xrspatial .zonal import regions
1313
14- from .general_checks import create_test_raster , has_cuda_and_cupy
14+ from .general_checks import create_test_raster , general_output_checks , has_cuda_and_cupy
1515
1616
1717@pytest .fixture
@@ -60,6 +60,40 @@ def result_default_stats():
6060 return expected_result
6161
6262
63+ @pytest .fixture
64+ def result_default_stats_dataarray ():
65+ expected_result = np .array (
66+ [[[0. , 0. , 1. , 1. , 2. , 2. , 2.4 , 2.4 ],
67+ [0. , 0. , 1. , 1. , 2. , 2. , 2.4 , 2.4 ],
68+ [0. , 0. , 1. , 1. , 2. , np .nan , 2.4 , 2.4 ]],
69+
70+ [[0. , 0. , 1. , 1. , 2. , 2. , 3. , 3. ],
71+ [0. , 0. , 1. , 1. , 2. , 2. , 3. , 3. ],
72+ [0. , 0. , 1. , 1. , 2. , np .nan , 3. , 3. ]],
73+
74+ [[0. , 0. , 1. , 1. , 2. , 2. , 0. , 0. ],
75+ [0. , 0. , 1. , 1. , 2. , 2. , 0. , 0. ],
76+ [0. , 0. , 1. , 1. , 2. , np .nan , 0. , 0. ]],
77+
78+ [[0. , 0. , 6. , 6. , 8. , 8. , 12. , 12. ],
79+ [0. , 0. , 6. , 6. , 8. , 8. , 12. , 12. ],
80+ [0. , 0. , 6. , 6. , 8. , np .nan , 12. , 12. ]],
81+
82+ [[0. , 0. , 0. , 0. , 0. , 0. , 1.2 , 1.2 ],
83+ [0. , 0. , 0. , 0. , 0. , 0. , 1.2 , 1.2 ],
84+ [0. , 0. , 0. , 0. , 0. , np .nan , 1.2 , 1.2 ]],
85+
86+ [[0. , 0. , 0. , 0. , 0. , 0. , 1.44 , 1.44 ],
87+ [0. , 0. , 0. , 0. , 0. , 0. , 1.44 , 1.44 ],
88+ [0. , 0. , 0. , 0. , 0. , np .nan , 1.44 , 1.44 ]],
89+
90+ [[5. , 5. , 6. , 6. , 4. , 4. , 5. , 5. ],
91+ [5. , 5. , 6. , 6. , 4. , 4. , 5. , 5. ],
92+ [5. , 5. , 6. , 6. , 4. , np .nan , 5. , 5. ]]]
93+ )
94+ return expected_result
95+
96+
6397@pytest .fixture
6498def result_zone_ids_stats ():
6599 zone_ids = [0 , 3 ]
@@ -76,6 +110,41 @@ def result_zone_ids_stats():
76110 return zone_ids , expected_result
77111
78112
113+ @pytest .fixture
114+ def result_zone_ids_stats_dataarray ():
115+ zone_ids = [0 , 3 ]
116+ expected_result = np .array (
117+ [[[0. , 0. , np .nan , np .nan , np .nan , np .nan , 2.4 , 2.4 ],
118+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 2.4 , 2.4 ],
119+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 2.4 , 2.4 ]],
120+
121+ [[0. , 0. , np .nan , np .nan , np .nan , np .nan , 3. , 3. ],
122+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 3. , 3. ],
123+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 3. , 3. ]],
124+
125+ [[0. , 0. , np .nan , np .nan , np .nan , np .nan , 0. , 0. ],
126+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 0. , 0. ],
127+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 0. , 0. ]],
128+
129+ [[0. , 0. , np .nan , np .nan , np .nan , np .nan , 12. , 12. ],
130+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 12. , 12. ],
131+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 12. , 12. ]],
132+
133+ [[0. , 0. , np .nan , np .nan , np .nan , np .nan , 1.2 , 1.2 ],
134+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 1.2 , 1.2 ],
135+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 1.2 , 1.2 ]],
136+
137+ [[0. , 0. , np .nan , np .nan , np .nan , np .nan , 1.44 , 1.44 ],
138+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 1.44 , 1.44 ],
139+ [0. , 0. , np .nan , np .nan , np .nan , np .nan , 1.44 , 1.44 ]],
140+
141+ [[5. , 5. , np .nan , np .nan , np .nan , np .nan , 5. , 5. ],
142+ [5. , 5. , np .nan , np .nan , np .nan , np .nan , 5. , 5. ],
143+ [5. , 5. , np .nan , np .nan , np .nan , np .nan , 5. , 5. ]]])
144+
145+ return zone_ids , expected_result
146+
147+
79148def _double_sum (values ):
80149 return values .sum () * 2
81150
@@ -96,6 +165,22 @@ def result_custom_stats():
96165 return nodata_values , zone_ids , expected_result
97166
98167
168+ @pytest .fixture
169+ def result_custom_stats_dataarray ():
170+ zone_ids = [1 , 2 ]
171+ nodata_values = 0
172+ expected_result = np .array (
173+ [[[np .nan , np .nan , 12. , 12. , 16. , 16. , np .nan , np .nan ],
174+ [np .nan , np .nan , 12. , 12. , 16. , 16. , np .nan , np .nan ],
175+ [np .nan , np .nan , 12. , 12. , 16. , np .nan , np .nan , np .nan ]],
176+
177+ [[np .nan , np .nan , 0. , 0. , 0. , 0. , np .nan , np .nan ],
178+ [np .nan , np .nan , 0. , 0. , 0. , 0. , np .nan , np .nan ],
179+ [np .nan , np .nan , 0. , 0. , 0. , np .nan , np .nan , np .nan ]]]
180+ )
181+ return nodata_values , zone_ids , expected_result
182+
183+
99184@pytest .fixture
100185def result_count_crosstab_2d ():
101186 zone_ids = [1 , 2 , 3 ]
@@ -174,6 +259,22 @@ def test_default_stats(backend, data_zones, data_values_2d, result_default_stats
174259 check_results (backend , df_result , result_default_stats )
175260
176261
262+ @pytest .mark .parametrize ("backend" , ['numpy' ])
263+ def test_default_stats_dataarray (
264+ backend , data_zones , data_values_2d , result_default_stats_dataarray
265+ ):
266+ dataarray_result = stats (
267+ zones = data_zones , values = data_values_2d , return_type = 'xarray.DataArray'
268+ )
269+ general_output_checks (
270+ data_values_2d ,
271+ dataarray_result ,
272+ result_default_stats_dataarray ,
273+ verify_dtype = False ,
274+ verify_attrs = False ,
275+ )
276+
277+
177278@pytest .mark .parametrize ("backend" , ['numpy' , 'dask+numpy' , 'cupy' ])
178279def test_zone_ids_stats (backend , data_zones , data_values_2d , result_zone_ids_stats ):
179280 if backend == 'cupy' and not has_cuda_and_cupy ():
@@ -184,6 +285,19 @@ def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_sta
184285 check_results (backend , df_result , expected_result )
185286
186287
288+ @pytest .mark .parametrize ("backend" , ['numpy' ])
289+ def test_zone_ids_stats_dataarray (
290+ backend , data_zones , data_values_2d , result_zone_ids_stats_dataarray
291+ ):
292+ zone_ids , expected_result = result_zone_ids_stats_dataarray
293+ dataarray_result = stats (
294+ zones = data_zones , values = data_values_2d , zone_ids = zone_ids , return_type = 'xarray.DataArray'
295+ )
296+ general_output_checks (
297+ data_values_2d , dataarray_result , expected_result , verify_dtype = False , verify_attrs = False
298+ )
299+
300+
187301@pytest .mark .parametrize ("backend" , ['numpy' , 'cupy' ])
188302def test_custom_stats (backend , data_zones , data_values_2d , result_custom_stats ):
189303 # ---- custom stats (NumPy and CuPy only) ----
@@ -203,6 +317,23 @@ def test_custom_stats(backend, data_zones, data_values_2d, result_custom_stats):
203317 check_results (backend , df_result , expected_result )
204318
205319
320+ @pytest .mark .parametrize ("backend" , ['numpy' ])
321+ def test_custom_stats_dataarray (backend , data_zones , data_values_2d , result_custom_stats_dataarray ):
322+ # ---- custom stats returns a xr.DataArray (NumPy only) ----
323+ custom_stats = {
324+ 'double_sum' : _double_sum ,
325+ 'range' : _range ,
326+ }
327+ nodata_values , zone_ids , expected_result = result_custom_stats_dataarray
328+ dataarray_result = stats (
329+ zones = data_zones , values = data_values_2d , stats_funcs = custom_stats ,
330+ zone_ids = zone_ids , nodata_values = nodata_values , return_type = 'xarray.DataArray'
331+ )
332+ general_output_checks (
333+ data_values_2d , dataarray_result , expected_result , verify_dtype = False , verify_attrs = False
334+ )
335+
336+
206337@pytest .mark .parametrize ("backend" , ['numpy' , 'dask+numpy' ])
207338def test_count_crosstab_2d (backend , data_zones , data_values_2d , result_count_crosstab_2d ):
208339 zone_ids , cat_ids , expected_result = result_count_crosstab_2d
0 commit comments