@@ -343,49 +343,57 @@ def test_paired_distances_callable():
343343def test_pairwise_distances_argmin_min ():
344344 # Check pairwise minimum distances computation for any metric
345345 X = [[0 ], [1 ]]
346- Y = [[- 1 ], [2 ]]
346+ Y = [[- 2 ], [3 ]]
347347
348348 Xsp = dok_matrix (X )
349349 Ysp = csr_matrix (Y , dtype = np .float32 )
350350
351- # euclidean metric
352- D , E = pairwise_distances_argmin_min (X , Y , metric = "euclidean" )
353- D2 = pairwise_distances_argmin (X , Y , metric = "euclidean" )
354- assert_array_almost_equal (D , [0 , 1 ])
355- assert_array_almost_equal (D2 , [0 , 1 ])
356- assert_array_almost_equal (D , [0 , 1 ])
357- assert_array_almost_equal (E , [1. , 1. ])
351+ expected_idx = [0 , 1 ]
352+ expected_vals = [2 , 2 ]
353+ expected_vals_sq = [4 , 4 ]
358354
355+ # euclidean metric
356+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "euclidean" )
357+ idx2 = pairwise_distances_argmin (X , Y , metric = "euclidean" )
358+ assert_array_almost_equal (idx , expected_idx )
359+ assert_array_almost_equal (idx2 , expected_idx )
360+ assert_array_almost_equal (vals , expected_vals )
359361 # sparse matrix case
360- Dsp , Esp = pairwise_distances_argmin_min (Xsp , Ysp , metric = "euclidean" )
361- assert_array_equal ( Dsp , D )
362- assert_array_equal ( Esp , E )
362+ idxsp , valssp = pairwise_distances_argmin_min (Xsp , Ysp , metric = "euclidean" )
363+ assert_array_almost_equal ( idxsp , expected_idx )
364+ assert_array_almost_equal ( valssp , expected_vals )
363365 # We don't want np.matrix here
364- assert_equal (type (Dsp ), np .ndarray )
365- assert_equal (type (Esp ), np .ndarray )
366+ assert_equal (type (idxsp ), np .ndarray )
367+ assert_equal (type (valssp ), np .ndarray )
368+
369+ # euclidean metric squared
370+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "euclidean" ,
371+ metric_kwargs = {"squared" : True })
372+ assert_array_almost_equal (idx , expected_idx )
373+ assert_array_almost_equal (vals , expected_vals_sq )
366374
367375 # Non-euclidean scikit-learn metric
368- D , E = pairwise_distances_argmin_min (X , Y , metric = "manhattan" )
369- D2 = pairwise_distances_argmin (X , Y , metric = "manhattan" )
370- assert_array_almost_equal (D , [ 0 , 1 ] )
371- assert_array_almost_equal (D2 , [ 0 , 1 ] )
372- assert_array_almost_equal (E , [ 1. , 1. ] )
373- D , E = pairwise_distances_argmin_min ( Xsp , Ysp , metric = "manhattan" )
374- D2 = pairwise_distances_argmin (Xsp , Ysp , metric = "manhattan" )
375- assert_array_almost_equal (D , [ 0 , 1 ] )
376- assert_array_almost_equal (E , [ 1. , 1. ] )
376+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "manhattan" )
377+ idx2 = pairwise_distances_argmin (X , Y , metric = "manhattan" )
378+ assert_array_almost_equal (idx , expected_idx )
379+ assert_array_almost_equal (idx2 , expected_idx )
380+ assert_array_almost_equal (vals , expected_vals )
381+ # sparse matrix case
382+ idxsp , valssp = pairwise_distances_argmin_min (Xsp , Ysp , metric = "manhattan" )
383+ assert_array_almost_equal (idxsp , expected_idx )
384+ assert_array_almost_equal (valssp , expected_vals )
377385
378386 # Non-euclidean Scipy distance (callable)
379- D , E = pairwise_distances_argmin_min (X , Y , metric = minkowski ,
380- metric_kwargs = {"p" : 2 })
381- assert_array_almost_equal (D , [ 0 , 1 ] )
382- assert_array_almost_equal (E , [ 1. , 1. ] )
387+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = minkowski ,
388+ metric_kwargs = {"p" : 2 })
389+ assert_array_almost_equal (idx , expected_idx )
390+ assert_array_almost_equal (vals , expected_vals )
383391
384392 # Non-euclidean Scipy distance (string)
385- D , E = pairwise_distances_argmin_min (X , Y , metric = "minkowski" ,
386- metric_kwargs = {"p" : 2 })
387- assert_array_almost_equal (D , [ 0 , 1 ] )
388- assert_array_almost_equal (E , [ 1. , 1. ] )
393+ idx , vals = pairwise_distances_argmin_min (X , Y , metric = "minkowski" ,
394+ metric_kwargs = {"p" : 2 })
395+ assert_array_almost_equal (idx , expected_idx )
396+ assert_array_almost_equal (vals , expected_vals )
389397
390398 # Compare with naive implementation
391399 rng = np .random .RandomState (0 )
0 commit comments