2121from jax import numpy as jnp
2222from jax ._src import array
2323from jax ._src import xla_bridge
24+ from jax ._src .lax .lax import _array_copy
2425from jax ._src .lib import xla_client
2526from jax ._src .lib import xla_extension_version
2627from jax ._src .typing import Array
28+ from jax ._src .api import device_put
2729
30+ DLPACK_VERSION = (0 , 8 )
31+ MIN_DLPACK_VERSION = (0 , 5 )
2832
2933# A set of dtypes that dlpack supports.
3034# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +52,34 @@ class DLDeviceType(enum.IntEnum):
4852 kDLCUDA = 2
4953 kDLROCM = 10
5054
55+ def _to_dlpack (x : Array , stream : int | Any | None ,
56+ src_device : xla_client .Device | None = None ,
57+ device : xla_client .Device | None = None ,
58+ copy : bool | None = None ):
59+
60+ if src_device is None :
61+ src_device , = x .devices ()
62+ if device and (src_device is None or device != src_device ):
63+ if copy is not None and not copy :
64+ raise ValueError (
65+ f"Specified { device = } which requires a copy since the source device "
66+ f"is { repr (src_device )} , however copy=False. Set copy=True or "
67+ "copy=None to perform the requested operation."
68+ )
69+ else :
70+ arr = device_put (x , device )
71+ else :
72+ arr = _array_copy (x ) if copy else x
73+ return xla_client ._xla .buffer_to_dlpack_managed_tensor (
74+ arr .addressable_data (0 ), stream = stream
75+ )
5176
5277def to_dlpack (x : Array , take_ownership : bool = False ,
53- stream : int | Any | None = None ):
78+ stream : int | Any | None = None ,
79+ src_device : xla_client .Device | None = None ,
80+ dl_device : tuple [DLDeviceType , int ] | None = None ,
81+ max_version : tuple [int , int ] | None = None ,
82+ copy : bool | None = None ):
5483 """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
5584
5685 Args:
@@ -60,27 +89,97 @@ def to_dlpack(x: Array, take_ownership: bool = False,
6089 stream: optional platform-dependent stream to wait on until the buffer is
6190 ready. This corresponds to the `stream` argument to ``__dlpack__``
6291 documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
92+ src_device: either a CPU or GPU :class:`~jax.Device`.
93+ dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
94+ format e.g. as produced by ``__dlpack_device__``.
95+ max_version: the maximum DLPack version that the consumer (i.e. caller of
96+ ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
97+ This function is not guaranteed to return a capsule of version
98+ ``max_version``.
99+ copy: a boolean indicating whether or not to copy the input. If
100+ ``copy=True`` then the function must always copy. When
101+ ``copy=False`` then the function must never copy, and must raise an error
102+ when a copy is deemed necessary. If ``copy=None`` then the function must
103+ avoid a copy if possible but may copy if needed.
63104
64105 Returns:
65- A dlpack PyCapsule object.
106+ A DLPack PyCapsule object.
66107
67108 Note:
68- While JAX arrays are always immutable, dlpack buffers cannot be marked as
69- immutable, and it is possible for processes external to JAX to mutate them
70- in-place. If a dlpack buffer derived from a JAX array is mutated, it may
71- lead to undefined behavior when using the associated JAX array.
109+ While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
110+ cannot be marked as immutable, and it is possible for processes external
111+ to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
112+ is mutated, it may lead to undefined behavior when using the associated JAX
113+ array. When JAX eventually supports ``DLManagedTensorVersioned``
114+ (DLPack 1.0), it will be possible to specify that a buffer is read-only.
72115 """
73116 if not isinstance (x , array .ArrayImpl ):
74117 raise TypeError ("Argument to to_dlpack must be a jax.Array, "
75118 f"got { type (x )} " )
76- assert len (x .devices ()) == 1
77119 if take_ownership :
78120 warnings .warn (
79121 "take_ownership in to_dlpack is deprecated and it is a no-op."
80122 )
81- return xla_client ._xla .buffer_to_dlpack_managed_tensor (
82- x .addressable_data (0 ), stream = stream
83- ) # type: ignore
123+
124+ device = None
125+ dl_device_type , local_hardware_id = dl_device if dl_device else (None , None )
126+ if dl_device_type :
127+ try :
128+ dl_device_platform = {
129+ DLDeviceType .kDLCPU : "cpu" ,
130+ DLDeviceType .kDLCUDA : "cuda" ,
131+ DLDeviceType .kDLROCM : "rocm" ,
132+ }[dl_device_type ]
133+ backend = xla_bridge .get_backend (dl_device_platform )
134+ device = backend .device_from_local_hardware_id (local_hardware_id )
135+ except TypeError :
136+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
137+ # recommends using BufferError.
138+ raise BufferError (
139+ "The device specification passed to to_dlpack contains an unsupported "
140+ f"device type (DLDeviceType: { dl_device_type } )" )
141+
142+ # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
143+ # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
144+ # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0)
145+ if max_version is None :
146+ # Backwards compatible default
147+ return _to_dlpack (
148+ x , stream = stream ,
149+ src_device = src_device ,
150+ device = device ,
151+ copy = copy
152+ )
153+ else :
154+ if max_version >= DLPACK_VERSION :
155+ # Latest
156+ return _to_dlpack (
157+ x , stream = stream ,
158+ src_device = src_device ,
159+ device = device ,
160+ copy = copy
161+ )
162+ if max_version [0 ] == DLPACK_VERSION [0 ]:
163+ # ABI compatible
164+ return _to_dlpack (
165+ x , stream = stream ,
166+ src_device = src_device ,
167+ device = device ,
168+ copy = copy
169+ )
170+ elif max_version >= MIN_DLPACK_VERSION :
171+ # Oldest supported
172+ return _to_dlpack (
173+ x , stream = stream ,
174+ src_device = src_device ,
175+ device = device ,
176+ copy = copy
177+ )
178+ else :
179+ raise BufferError (
180+ f"JAX does not support any version below { MIN_DLPACK_VERSION } but "
181+ f"version ({ max_version } ) was requested."
182+ )
84183
85184
86185def from_dlpack (external_array ):
@@ -110,12 +209,12 @@ def from_dlpack(external_array):
110209 DLDeviceType .kDLCUDA : "cuda" ,
111210 DLDeviceType .kDLROCM : "rocm" ,
112211 }[dl_device_type ]
113- except TypeError :
212+ except TypeError as err :
114213 # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
115214 # TypeError.
116- raise TypeError (
215+ raise BufferError (
117216 "Array passed to from_dlpack is on unsupported device type "
118- f"(DLDeviceType: { dl_device_type } , array: { external_array } " )
217+ f"(DLDeviceType: { dl_device_type } , array: { external_array } " ) from err
119218
120219 backend = xla_bridge .get_backend (device_platform )
121220 device = backend .device_from_local_hardware_id (device_id )
0 commit comments