|
| 1 | +use ndarray::arr2; |
| 2 | +use ndarray::*; |
| 3 | +use ndarray_linalg::rank::Rank; |
| 4 | +use ndarray_linalg::*; |
| 5 | +use rand::{seq::SliceRandom, thread_rng}; |
| 6 | + |
| 7 | +/// creates a zero matrix which always has rank zero |
| 8 | +pub fn zero_rank<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D> |
| 9 | +where |
| 10 | + A: Scalar, |
| 11 | + S: DataOwned<Elem = A>, |
| 12 | + D: Dimension, |
| 13 | + Sh: ShapeBuilder<Dim = D>, |
| 14 | +{ |
| 15 | + ArrayBase::zeros(sh) |
| 16 | +} |
| 17 | + |
| 18 | +/// creates a random matrix and repeatedly creates a linear dependency between rows until the |
| 19 | +/// rank drops. |
| 20 | +pub fn partial_rank<A, Sh>(sh: Sh) -> Array2<A> |
| 21 | +where |
| 22 | + A: Scalar + Lapack, |
| 23 | + Sh: ShapeBuilder<Dim = Ix2>, |
| 24 | +{ |
| 25 | + let mut rng = thread_rng(); |
| 26 | + let mut result: Array2<A> = random(sh); |
| 27 | + println!("before: {:?}", result); |
| 28 | + |
| 29 | + let (n, m) = result.dim(); |
| 30 | + println!("(n, m) => ({:?},{:?})", n, m); |
| 31 | + |
| 32 | + // create randomized row iterator |
| 33 | + let min_dim = n.min(m); |
| 34 | + let mut row_indexes = (0..min_dim).into_iter().collect::<Vec<usize>>(); |
| 35 | + row_indexes.as_mut_slice().shuffle(&mut rng); |
| 36 | + let mut row_index_iter = row_indexes.iter().cycle(); |
| 37 | + |
| 38 | + for count in 1..=10 { |
| 39 | + println!("count: {}", count); |
| 40 | + let (&x, &y) = ( |
| 41 | + row_index_iter.next().unwrap(), |
| 42 | + row_index_iter.next().unwrap(), |
| 43 | + ); |
| 44 | + let (from_row_index, to_row_index) = if x < y { (x, y) } else { (y, x) }; |
| 45 | + println!("(r_f, r_t) => ({:?},{:?})", from_row_index, to_row_index); |
| 46 | + |
| 47 | + let mut it = result.outer_iter_mut(); |
| 48 | + let from_row = it.nth(from_row_index).unwrap(); |
| 49 | + let mut to_row = it.nth(to_row_index - (from_row_index + 1)).unwrap(); |
| 50 | + |
| 51 | + // set the to_row with the value of the from_row multiplied by rand_multiple |
| 52 | + let rand_multiple = A::rand(&mut rng); |
| 53 | + println!("rand_multiple: {:?}", rand_multiple); |
| 54 | + Zip::from(&mut to_row) |
| 55 | + .and(&from_row) |
| 56 | + .for_each(|r1, r2| *r1 = *r2 * rand_multiple); |
| 57 | + |
| 58 | + if let Ok(rank) = result.rank() { |
| 59 | + println!("result: {:?}", result); |
| 60 | + println!("rank: {:?}", rank); |
| 61 | + if rank > 0 && rank < min_dim { |
| 62 | + return result; |
| 63 | + } |
| 64 | + } |
| 65 | + } |
| 66 | + unreachable!("unable to generate random partial rank matrix after making 10 mutations") |
| 67 | +} |
| 68 | + |
| 69 | +/// creates a random matrix and insures it is full rank. |
| 70 | +pub fn full_rank<A, Sh>(sh: Sh) -> Array2<A> |
| 71 | +where |
| 72 | + A: Scalar + Lapack, |
| 73 | + Sh: ShapeBuilder<Dim = Ix2> + Clone, |
| 74 | +{ |
| 75 | + for _ in 0..10 { |
| 76 | + let r: Array2<A> = random(sh.clone()); |
| 77 | + let (n, m) = r.dim(); |
| 78 | + let n = n.min(m); |
| 79 | + if let Ok(rank) = r.rank() { |
| 80 | + println!("result: {:?}", r); |
| 81 | + println!("rank: {:?}", rank); |
| 82 | + if rank == n { |
| 83 | + return r; |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + unreachable!("unable to generate random full rank matrix in 10 tries") |
| 88 | +} |
| 89 | + |
| 90 | +fn test<T: Scalar + Lapack>(a: &Array2<T>, tolerance: T::Real) { |
| 91 | + println!("a = \n{:?}", &a); |
| 92 | + let a_plus: Array2<_> = a.pinv(None).unwrap(); |
| 93 | + println!("a_plus = \n{:?}", &a_plus); |
| 94 | + let ident = a.dot(&a_plus); |
| 95 | + assert_close_l2!(&ident.dot(a), &a, tolerance); |
| 96 | + assert_close_l2!(&a_plus.dot(&ident), &a_plus, tolerance); |
| 97 | +} |
| 98 | + |
| 99 | +macro_rules! test_both_impl { |
| 100 | + ($type:ty, $test:tt, $n:expr, $m:expr, $t:expr) => { |
| 101 | + paste::item! { |
| 102 | + #[test] |
| 103 | + fn [<pinv_test_ $type _ $test _ $n x $m _r>]() { |
| 104 | + let a: Array2<$type> = $test(($n, $m)); |
| 105 | + test::<$type>(&a, $t); |
| 106 | + } |
| 107 | + |
| 108 | + #[test] |
| 109 | + fn [<pinv_test_ $type _ $test _ $n x $m _c>]() { |
| 110 | + let a = $test(($n, $m).f()); |
| 111 | + test::<$type>(&a, $t); |
| 112 | + } |
| 113 | + } |
| 114 | + }; |
| 115 | +} |
| 116 | + |
| 117 | +macro_rules! test_pinv_impl { |
| 118 | + ($type:ty, $n:expr, $m:expr, $a:expr) => { |
| 119 | + test_both_impl!($type, zero_rank, $n, $m, $a); |
| 120 | + test_both_impl!($type, partial_rank, $n, $m, $a); |
| 121 | + test_both_impl!($type, full_rank, $n, $m, $a); |
| 122 | + }; |
| 123 | +} |
| 124 | + |
| 125 | +test_pinv_impl!(f32, 3, 3, 1e-4); |
| 126 | +test_pinv_impl!(f32, 4, 3, 1e-4); |
| 127 | +test_pinv_impl!(f32, 3, 4, 1e-4); |
| 128 | + |
| 129 | +test_pinv_impl!(c32, 3, 3, 1e-4); |
| 130 | +test_pinv_impl!(c32, 4, 3, 1e-4); |
| 131 | +test_pinv_impl!(c32, 3, 4, 1e-4); |
| 132 | + |
| 133 | +test_pinv_impl!(f64, 3, 3, 1e-12); |
| 134 | +test_pinv_impl!(f64, 4, 3, 1e-12); |
| 135 | +test_pinv_impl!(f64, 3, 4, 1e-12); |
| 136 | + |
| 137 | +test_pinv_impl!(c64, 3, 3, 1e-12); |
| 138 | +test_pinv_impl!(c64, 4, 3, 1e-12); |
| 139 | +test_pinv_impl!(c64, 3, 4, 1e-12); |
| 140 | + |
| 141 | +// |
| 142 | +// This matrix was taken from 7.1.1 Test1 in |
| 143 | +// "On Moore-Penrose Pseudoinverse Computation for Stiffness Matrices Resulting |
| 144 | +// from Higher Order Approximation" by Marek Klimczak |
| 145 | +// https://doi.org/10.1155/2019/5060397 |
| 146 | +// |
| 147 | +#[test] |
| 148 | +fn pinv_test_single_value_less_then_threshold_3x3() { |
| 149 | + #[rustfmt::skip] |
| 150 | + let a: Array2<f64> = arr2(&[ |
| 151 | + [ 1., -1., 0.], |
| 152 | + [-1., 2., -1.], |
| 153 | + [ 0., -1., 1.] |
| 154 | + ], |
| 155 | + ); |
| 156 | + #[rustfmt::skip] |
| 157 | + let a_plus_actual: Array2<f64> = arr2(&[ |
| 158 | + [ 5. / 9., -1. / 9., -4. / 9.], |
| 159 | + [-1. / 9., 2. / 9., -1. / 9.], |
| 160 | + [-4. / 9., -1. / 9., 5. / 9.], |
| 161 | + ], |
| 162 | + ); |
| 163 | + let a_plus: Array2<_> = a.pinv(None).unwrap(); |
| 164 | + println!("a_plus -> {:?}", &a_plus); |
| 165 | + println!("a_plus_actual -> {:?}", &a_plus); |
| 166 | + assert_close_l2!(&a_plus, &a_plus_actual, 1e-15); |
| 167 | +} |
0 commit comments