2424from jax ._src .lib import xla_client
2525from jax ._src .lib import xla_extension_version
2626from jax ._src .typing import Array
27+ from jax ._src .api import device_put
2728
29+ DLPACK_VERSION = (0 , 1 )
30+ MIN_DLPACK_VERSION = (0 , 1 )
2831
2932# A set of dtypes that dlpack supports.
3033# Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -48,9 +51,33 @@ class DLDeviceType(enum.IntEnum):
4851 kDLCUDA = 2
4952 kDLROCM = 10
5053
54+ def _to_dlpack (x : Array , stream : int | Any | None ,
55+ device : xla_client .Device | None = None ,
56+ dlpack_device : xla_client .Device | None = None ,
57+ copy : bool | None = None ):
58+ arr = None
59+ if dlpack_device and dlpack_device != device :
60+ if copy is not None and not copy :
61+ raise ValueError (
62+ f"Specified { dlpack_device = } which requires a copy since the source device "
63+ f"is { repr (device )} , however copy=False. Set copy=True or "
64+ "copy=None to perform the requested operation."
65+ )
66+ else :
67+ arr = device_put (x , dlpack_device )
68+ if arr is None :
69+ arr = x .copy () if copy else x
70+
71+ return xla_client ._xla .buffer_to_dlpack_managed_tensor (
72+ arr .addressable_data (0 ), stream = stream
73+ ) # type: ignore
5174
5275def to_dlpack (x : Array , take_ownership : bool = False ,
53- stream : int | Any | None = None ):
76+ stream : int | Any | None = None ,
77+ device : xla_client .Device | None = None ,
78+ dl_device : tuple [DLDeviceType , int ] | None = None ,
79+ max_version : tuple [int , int ] | None = None ,
80+ copy : bool | None = None ):
5481 """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
5582
5683 Args:
@@ -73,14 +100,40 @@ def to_dlpack(x: Array, take_ownership: bool = False,
73100 if not isinstance (x , array .ArrayImpl ):
74101 raise TypeError ("Argument to to_dlpack must be a jax.Array, "
75102 f"got { type (x )} " )
76- assert len (x .devices ()) == 1
77103 if take_ownership :
78104 warnings .warn (
79105 "take_ownership in to_dlpack is deprecated and it is a no-op."
80106 )
81- return xla_client ._xla .buffer_to_dlpack_managed_tensor (
82- x .addressable_data (0 ), stream = stream
83- ) # type: ignore
107+
108+ dlpack_device = None
109+ dl_device_type , local_hardware_id = dl_device if dl_device else (None , None )
110+ if dl_device_type :
111+ try :
112+ dl_device_platform = {
113+ DLDeviceType .kDLCPU : "cpu" ,
114+ DLDeviceType .kDLCUDA : "cuda" ,
115+ DLDeviceType .kDLROCM : "rocm" ,
116+ }[dl_device_type ]
117+ backend = xla_bridge .get_backend (dl_device_platform )
118+ dlpack_device = backend .device_from_local_hardware_id (local_hardware_id )
119+ except TypeError :
120+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
121+ # recommends using BufferError.
122+ raise BufferError (
123+ "The device specification passed to to_dlpack contains an unsupported "
124+ f"device type (DLDeviceType: { dl_device_type } )" )
125+
126+ if max_version is None or max_version [0 ] >= DLPACK_VERSION [0 ]:
127+ return _to_dlpack (x , stream = stream , device = device , dlpack_device = dlpack_device , copy = copy )
128+ elif max_version >= MIN_DLPACK_VERSION :
129+ # Legacy path to be implemented when XLA adopts DLManagedTensorVersioned format
130+ raise RuntimeError ("This branch should be unreachable. "
131+ "Please open a bug if you see this." )
132+ else :
133+ raise BufferError (
134+ f"JAX does not support any version below { MIN_DLPACK_VERSION } but "
135+ f"version ({ max_version } ) was requested."
136+ )
84137
85138
86139def from_dlpack (external_array ):
0 commit comments