2424from  jax ._src .lib  import  xla_client 
2525from  jax ._src .lib  import  xla_extension_version 
2626from  jax ._src .typing  import  Array 
27+ from  jax ._src .api  import  device_put 
2728
29+ DLPACK_VERSION  =  (0 , 1 )
30+ MIN_DLPACK_VERSION  =  (0 , 1 )
2831
2932# A set of dtypes that dlpack supports. 
3033# Note: Make sure to use a "type", not a dtype instance, when looking up this set 
@@ -48,9 +51,32 @@ class DLDeviceType(enum.IntEnum):
4851 kDLCUDA  =  2 
4952 kDLROCM  =  10 
5053
54+ def  _to_dlpack (x : Array , stream : int  |  Any  |  None ,
55+  device : xla_client .Device  |  None  =  None ,
56+  dlpack_device : xla_client .Device  |  None  =  None ,
57+  copy : bool  |  None  =  None ):
58+  if  dlpack_device  and  dlpack_device  !=  device :
59+  if  copy  is  not   None  and  not  copy :
60+  raise  ValueError (
61+  f"Specified { dlpack_device = }   which requires a copy since the source device " 
62+  f"is { repr (device )}  , however copy=False. Set copy=True or " 
63+  "copy=None to perform the requested operation." 
64+  )
65+  else :
66+  arr  =  device_put (x , dlpack_device )
67+  else :
68+  arr  =  x .copy () if  copy  else  x 
69+ 
70+  return  xla_client ._xla .buffer_to_dlpack_managed_tensor (
71+  arr .addressable_data (0 ), stream = stream 
72+  )
5173
5274def  to_dlpack (x : Array , take_ownership : bool  =  False ,
53-  stream : int  |  Any  |  None  =  None ):
75+  stream : int  |  Any  |  None  =  None ,
76+  device : xla_client .Device  |  None  =  None ,
77+  dl_device : tuple [DLDeviceType , int ] |  None  =  None ,
78+  max_version : tuple [int , int ] |  None  =  None ,
79+  copy  : bool  |  None  =  None ):
5480 """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``. 
5581
5682 Args: 
@@ -73,14 +99,40 @@ def to_dlpack(x: Array, take_ownership: bool = False,
7399 if  not  isinstance (x , array .ArrayImpl ):
74100 raise  TypeError ("Argument to to_dlpack must be a jax.Array, " 
75101 f"got { type (x )}  " )
76-  assert  len (x .devices ()) ==  1 
77102 if  take_ownership :
78103 warnings .warn (
79104 "take_ownership in to_dlpack is deprecated and it is a no-op." 
80105 )
81-  return  xla_client ._xla .buffer_to_dlpack_managed_tensor (
82-  x .addressable_data (0 ), stream = stream 
83-  ) # type: ignore 
106+ 
107+  dlpack_device  =  None 
108+  dl_device_type , local_hardware_id  =  dl_device  if  dl_device  else  (None , None )
109+  if  dl_device_type :
110+  try :
111+  dl_device_platform  =  {
112+  DLDeviceType .kDLCPU : "cpu" ,
113+  DLDeviceType .kDLCUDA : "cuda" ,
114+  DLDeviceType .kDLROCM : "rocm" ,
115+  }[dl_device_type ]
116+  backend  =  xla_bridge .get_backend (dl_device_platform )
117+  dlpack_device  =  backend .device_from_local_hardware_id (local_hardware_id )
118+  except  TypeError :
119+  # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html 
120+  # recommends using BufferError. 
121+  raise  BufferError (
122+  "The device specification passed to to_dlpack contains an unsupported " 
123+  f"device type (DLDeviceType: { dl_device_type }  )" )
124+ 
125+  if  max_version  is  None  or  max_version [0 ] >=  DLPACK_VERSION [0 ]:
126+  return  _to_dlpack (x , stream = stream , device = device , dlpack_device = dlpack_device , copy = copy )
127+  elif  max_version  >=  MIN_DLPACK_VERSION :
128+  # Legacy path to be implemented when XLA adopts DLManagedTensorVersioned format 
129+  raise  RuntimeError ("This branch should be unreachable. " 
130+  "Please open a bug if you see this." )
131+  else :
132+  raise  BufferError (
133+  f"JAX does not support any version below { MIN_DLPACK_VERSION }   but " 
134+  f"version ({ max_version }  ) was requested." 
135+  )
84136
85137
86138def  from_dlpack (external_array ):
@@ -110,12 +162,12 @@ def from_dlpack(external_array):
110162 DLDeviceType .kDLCUDA : "cuda" ,
111163 DLDeviceType .kDLROCM : "rocm" ,
112164 }[dl_device_type ]
113-  except  TypeError :
165+  except  TypeError   as   err :
114166 # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using 
115167 # TypeError. 
116-  raise  TypeError (
168+  raise  BufferError (
117169 "Array passed to from_dlpack is on unsupported device type " 
118-  f"(DLDeviceType: { dl_device_type }  , array: { external_array }  " )
170+  f"(DLDeviceType: { dl_device_type }  , array: { external_array }  " )  from   err 
119171
120172 backend  =  xla_bridge .get_backend (device_platform )
121173 device  =  backend .device_from_local_hardware_id (device_id )
0 commit comments