11from  __future__ import  annotations 
22
3- from  ...common  import  _aliases 
3+ from  typing  import  Callable 
4+ 
5+ from  ...common  import  _aliases , array_namespace 
46
57from  ..._internal  import  get_xp 
68
2931)
3032
3133from  typing  import  TYPE_CHECKING 
34+ 
3235if  TYPE_CHECKING :
3336 from  typing  import  Optional , Union 
3437
35-  from  ...common ._typing  import  Device , Dtype , Array , NestedSequence , SupportsBufferProtocol 
38+  from  ...common ._typing  import  (
39+  Device ,
40+  Dtype ,
41+  Array ,
42+  NestedSequence ,
43+  SupportsBufferProtocol ,
44+  )
3645
3746import  dask .array  as  da 
3847
3948isdtype  =  get_xp (np )(_aliases .isdtype )
4049unstack  =  get_xp (da )(_aliases .unstack )
4150
51+ 
4252# da.astype doesn't respect copy=True 
4353def  astype (
4454 x : Array ,
4555 dtype : Dtype ,
4656 / ,
4757 * ,
4858 copy : bool  =  True ,
49-  device : Optional [Device ] =  None 
59+  device : Optional [Device ] =  None , 
5060) ->  Array :
5161 """ 
5262 Array API compatibility wrapper for astype(). 
@@ -61,8 +71,10 @@ def astype(
6171 x  =  x .astype (dtype )
6272 return  x .copy () if  copy  else  x 
6373
74+ 
6475# Common aliases 
6576
77+ 
6678# This arange func is modified from the common one to 
6779# not pass stop/step as keyword arguments, which will cause 
6880# an error with dask 
@@ -189,6 +201,7 @@ def asarray(
189201 concatenate  as  concat ,
190202)
191203
204+ 
192205# dask.array.clip does not work unless all three arguments are provided. 
193206# Furthermore, the masking workaround in common._aliases.clip cannot work with 
194207# dask (meaning uint64 promoting to float64 is going to just be unfixed for 
@@ -205,8 +218,10 @@ def clip(
205218 See the corresponding documentation in the array library and/or the array API 
206219 specification for more details. 
207220 """ 
221+ 
208222 def  _isscalar (a ):
209223 return  isinstance (a , (int , float , type (None )))
224+ 
210225 min_shape  =  () if  _isscalar (min ) else  min .shape 
211226 max_shape  =  () if  _isscalar (max ) else  max .shape 
212227
@@ -228,12 +243,99 @@ def _isscalar(a):
228243
229244 return  astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
230245
231- # exclude these from all since dask.array has no sorting functions 
232- _da_unsupported  =  ['sort' , 'argsort' ]
233246
234- _common_aliases  =  [alias  for  alias  in  _aliases .__all__  if  alias  not  in _da_unsupported ]
247+ def  _ensure_single_chunk (x : Array , axis : int ) ->  tuple [Array , Callable [[Array ], Array ]]:
248+  """ 
249+  Make sure that Array is not broken into multiple chunks along axis. 
250+ 
251+  Returns 
252+  ------- 
253+  x : Array 
254+  The input Array with a single chunk along axis. 
255+  restore : Callable[Array, Array] 
256+  function to apply to the output to rechunk it back into reasonable chunks 
257+  """ 
258+  if  axis  <  0 :
259+  axis  +=  x .ndim 
260+  if  x .numblocks [axis ] <  2 :
261+  return  x , lambda  x : x 
262+ 
263+  # Break chunks on other axes in an attempt to keep chunk size low 
264+  x  =  x .rechunk ({i : - 1  if  i  ==  axis  else  "auto"  for  i  in  range (x .ndim )})
265+ 
266+  # Rather than reconstructing the original chunks, which can be a 
267+  # very expensive affair, just break down oversized chunks without 
268+  # incurring in any transfers over the network. 
269+  # This has the downside of a risk of overchunking if the array is 
270+  # then used in operations against other arrays that match the 
271+  # original chunking pattern. 
272+  return  x , lambda  x : x .rechunk ()
273+ 
274+ 
275+ def  sort (
276+  x : Array , / , * , axis : int  =  - 1 , descending : bool  =  False , stable : bool  =  True 
277+ ) ->  Array :
278+  """ 
279+  Array API compatibility layer around the lack of sort() in Dask. 
280+ 
281+  Warnings 
282+  -------- 
283+  This function temporarily rechunks the array along `axis` to a single chunk. 
284+  This can be extremely inefficient and can lead to out-of-memory errors. 
285+ 
286+  See the corresponding documentation in the array library and/or the array API 
287+  specification for more details. 
288+  """ 
289+  x , restore  =  _ensure_single_chunk (x , axis )
290+ 
291+  meta_xp  =  array_namespace (x ._meta )
292+  x  =  da .map_blocks (
293+  meta_xp .sort ,
294+  x ,
295+  axis = axis ,
296+  meta = x ._meta ,
297+  dtype = x .dtype ,
298+  descending = descending ,
299+  stable = stable ,
300+  )
301+ 
302+  return  restore (x )
235303
236- __all__  =  _common_aliases  +  ['__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
304+ 
305+ def  argsort (
306+  x : Array , / , * , axis : int  =  - 1 , descending : bool  =  False , stable : bool  =  True 
307+ ) ->  Array :
308+  """ 
309+  Array API compatibility layer around the lack of argsort() in Dask. 
310+ 
311+  See the corresponding documentation in the array library and/or the array API 
312+  specification for more details. 
313+ 
314+  Warnings 
315+  -------- 
316+  This function temporarily rechunks the array along `axis` into a single chunk. 
317+  This can be extremely inefficient and can lead to out-of-memory errors. 
318+  """ 
319+  x , restore  =  _ensure_single_chunk (x , axis )
320+ 
321+  meta_xp  =  array_namespace (x ._meta )
322+  dtype  =  meta_xp .argsort (x ._meta ).dtype 
323+  meta  =  meta_xp .astype (x ._meta , dtype )
324+  x  =  da .map_blocks (
325+  meta_xp .argsort ,
326+  x ,
327+  axis = axis ,
328+  meta = meta ,
329+  dtype = dtype ,
330+  descending = descending ,
331+  stable = stable ,
332+  )
333+ 
334+  return  restore (x )
335+ 
336+ 
337+ __all__  =  _aliases .__all__  +  [
338+  '__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
237339 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
238340 'atanh' , 'bitwise_left_shift' , 'bitwise_invert' ,
239341 'bitwise_right_shift' , 'concat' , 'pow' , 'iinfo' , 'finfo' , 'can_cast' ,
@@ -242,4 +344,4 @@ def _isscalar(a):
242344 'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
243345 'can_cast' , 'result_type' ]
244346
245- _all_ignore  =  ["get_xp" , "da" , "np" ]
347+ _all_ignore  =  ["Callable"  ,  "array_namespace" ,  " get_xp""da" , "np" ]
0 commit comments