8686# 
8787# License: MIT License 
8888
89- import  numpy  as  np 
9089import  os 
91- import  scipy 
92- import  scipy .linalg 
93- from  scipy .sparse  import  issparse , coo_matrix , csr_matrix 
94- import  scipy .special  as  special 
9590import  time 
9691import  warnings 
9792
93+ import  numpy  as  np 
94+ import  scipy 
95+ import  scipy .linalg 
96+ import  scipy .special  as  special 
97+ from  scipy .sparse  import  coo_matrix , csr_matrix , issparse 
9898
9999DISABLE_TORCH_KEY  =  'POT_BACKEND_DISABLE_PYTORCH' 
100100DISABLE_JAX_KEY  =  'POT_BACKEND_DISABLE_JAX' 
@@ -650,7 +650,7 @@ def std(self, a, axis=None):
650650 """ 
651651 raise  NotImplementedError ()
652652
653-  def  linspace (self , start , stop , num ):
653+  def  linspace (self , start , stop , num ,  type_as = None ):
654654 r""" 
655655 Returns a specified number of evenly spaced values over a given interval. 
656656
@@ -1208,8 +1208,11 @@ def median(self, a, axis=None):
12081208 def  std (self , a , axis = None ):
12091209 return  np .std (a , axis = axis )
12101210
1211-  def  linspace (self , start , stop , num ):
1212-  return  np .linspace (start , stop , num )
1211+  def  linspace (self , start , stop , num , type_as = None ):
1212+  if  type_as  is  None :
1213+  return  np .linspace (start , stop , num )
1214+  else :
1215+  return  np .linspace (start , stop , num , dtype = type_as .dtype )
12131216
12141217 def  meshgrid (self , a , b ):
12151218 return  np .meshgrid (a , b )
@@ -1579,8 +1582,11 @@ def median(self, a, axis=None):
15791582 def  std (self , a , axis = None ):
15801583 return  jnp .std (a , axis = axis )
15811584
1582-  def  linspace (self , start , stop , num ):
1583-  return  jnp .linspace (start , stop , num )
1585+  def  linspace (self , start , stop , num , type_as = None ):
1586+  if  type_as  is  None :
1587+  return  jnp .linspace (start , stop , num )
1588+  else :
1589+  return  self ._change_device (jnp .linspace (start , stop , num , dtype = type_as .dtype ), type_as )
15841590
15851591 def  meshgrid (self , a , b ):
15861592 return  jnp .meshgrid (a , b )
@@ -1986,6 +1992,7 @@ def concatenate(self, arrays, axis=0):
19861992
19871993 def  zero_pad (self , a , pad_width , value = 0 ):
19881994 from  torch .nn .functional  import  pad 
1995+ 
19891996 # pad_width is an array of ndim tuples indicating how many 0 before and after 
19901997 # we need to add. We first need to make it compliant with torch syntax, that 
19911998 # starts with the last dim, then second last, etc. 
@@ -2006,6 +2013,7 @@ def mean(self, a, axis=None):
20062013
20072014 def  median (self , a , axis = None ):
20082015 from  packaging  import  version 
2016+ 
20092017 # Since version 1.11.0, interpolation is available 
20102018 if  version .parse (torch .__version__ ) >=  version .parse ("1.11.0" ):
20112019 if  axis  is  not   None :
@@ -2026,8 +2034,11 @@ def std(self, a, axis=None):
20262034 else :
20272035 return  torch .std (a , unbiased = False )
20282036
2029-  def  linspace (self , start , stop , num ):
2030-  return  torch .linspace (start , stop , num , dtype = torch .float64 )
2037+  def  linspace (self , start , stop , num , type_as = None ):
2038+  if  type_as  is  None :
2039+  return  torch .linspace (start , stop , num )
2040+  else :
2041+  return  torch .linspace (start , stop , num , dtype = type_as .dtype , device = type_as .device )
20312042
20322043 def  meshgrid (self , a , b ):
20332044 try :
@@ -2427,8 +2438,12 @@ def median(self, a, axis=None):
24272438 def  std (self , a , axis = None ):
24282439 return  cp .std (a , axis = axis )
24292440
2430-  def  linspace (self , start , stop , num ):
2431-  return  cp .linspace (start , stop , num )
2441+  def  linspace (self , start , stop , num , type_as = None ):
2442+  if  type_as  is  None :
2443+  return  cp .linspace (start , stop , num )
2444+  else :
2445+  with  cp .cuda .Device (type_as .device ):
2446+  return  cp .linspace (start , stop , num , dtype = type_as .dtype )
24322447
24332448 def  meshgrid (self , a , b ):
24342449 return  cp .meshgrid (a , b )
@@ -2834,8 +2849,11 @@ def median(self, a, axis=None):
28342849 def  std (self , a , axis = None ):
28352850 return  tnp .std (a , axis = axis )
28362851
2837-  def  linspace (self , start , stop , num ):
2838-  return  tnp .linspace (start , stop , num )
2852+  def  linspace (self , start , stop , num , type_as = None ):
2853+  if  type_as  is  None :
2854+  return  tnp .linspace (start , stop , num )
2855+  else :
2856+  return  tnp .linspace (start , stop , num , dtype = type_as .dtype )
28392857
28402858 def  meshgrid (self , a , b ):
28412859 return  tnp .meshgrid (a , b )
0 commit comments