Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ where
S: Data,
{
let n = a.len();
a.into_shape((n)).unwrap()
a.into_shape(n).unwrap()
}

pub fn into_matrix<A, S>(l: MatrixLayout, a: Vec<A>) -> Result<ArrayBase<S, Ix2>>
Expand Down
14 changes: 7 additions & 7 deletions src/diagonal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub trait IntoDiagonal<S: Data> {
}

pub trait AsDiagonal<A> {
fn as_diagonal<'a>(&'a self) -> Diagonal<ViewRepr<&'a A>>;
fn as_diagonal(&self) -> Diagonal<ViewRepr<&A>>;
}

impl<S: Data> IntoDiagonal<S> for ArrayBase<S, Ix1> {
Expand All @@ -25,7 +25,7 @@ impl<S: Data> IntoDiagonal<S> for ArrayBase<S, Ix1> {
}

impl<A, S: Data<Elem = A>> AsDiagonal<A> for ArrayBase<S, Ix1> {
fn as_diagonal<'a>(&'a self) -> Diagonal<ViewRepr<&'a A>> {
fn as_diagonal(&self) -> Diagonal<ViewRepr<&A>> {
Diagonal { diag: self.view() }
}
}
Expand All @@ -36,7 +36,7 @@ where
S: Data<Elem = A>,
Sr: DataMut<Elem = A>,
{
fn op_inplace<'a>(&self, mut a: &'a mut ArrayBase<Sr, Ix1>) -> &'a mut ArrayBase<Sr, Ix1> {
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix1>) -> &'a mut ArrayBase<Sr, Ix1> {
for (val, d) in a.iter_mut().zip(self.diag.iter()) {
*val = *val * *d;
}
Expand Down Expand Up @@ -64,7 +64,7 @@ where
Sr: DataOwned<Elem = A> + DataMut,
{
fn op_into(&self, mut a: ArrayBase<Sr, Ix1>) -> ArrayBase<Sr, Ix1> {
self.op(&mut a);
self.op_inplace(&mut a);
a
}
}
Expand All @@ -75,8 +75,8 @@ where
S: Data<Elem = A>,
Sr: DataMut<Elem = A>,
{
fn op_inplace<'a>(&self, mut a: &'a mut ArrayBase<Sr, Ix2>) -> &'a mut ArrayBase<Sr, Ix2> {
let ref d = self.diag;
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix2>) -> &'a mut ArrayBase<Sr, Ix2> {
let d = &self.diag;
for ((i, _), val) in a.indexed_iter_mut() {
*val = *val * d[i];
}
Expand Down Expand Up @@ -104,7 +104,7 @@ where
Sr: DataOwned<Elem = A> + DataMut,
{
fn op_into(&self, mut a: ArrayBase<Sr, Ix2>) -> ArrayBase<Sr, Ix2> {
self.op(&mut a);
self.op_inplace(&mut a);
a
}
}
2 changes: 1 addition & 1 deletion src/lapack_traits/cholesky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub trait Cholesky_: Sized {
macro_rules! impl_cholesky {
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
impl Cholesky_ for $scalar {
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> {
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, n);
into_result(info, ())
Expand Down
2 changes: 1 addition & 1 deletion src/lapack_traits/qr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl QR_ for $scalar {
into_result(info, ())
}

unsafe fn qr(l: MatrixLayout, mut a: &mut [Self]) -> Result<Vec<Self>> {
unsafe fn qr(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>> {
let tau = Self::householder(l, a)?;
let r = Vec::from(&*a);
Self::q(l, a, &tau)?;
Expand Down
9 changes: 4 additions & 5 deletions src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ impl MatrixLayout {

pub fn lda(&self) -> LDA {
match *self {
MatrixLayout::C((_, lda)) => lda,
MatrixLayout::F((_, lda)) => lda,
MatrixLayout::C((_, lda)) | MatrixLayout::F((_, lda)) => lda,
}
}

Expand Down Expand Up @@ -123,7 +122,7 @@ where
}

fn as_allocated(&self) -> Result<&[A]> {
Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?)
Ok(self.as_slice_memory_order().ok_or_else(MemoryContError::new)?)
}
}

Expand All @@ -132,8 +131,8 @@ where
S: DataMut<Elem = A>,
{
fn as_allocated_mut(&mut self) -> Result<&mut [A]> {
Ok(self.as_slice_memory_order_mut().ok_or(
MemoryContError::new(),
Ok(self.as_slice_memory_order_mut().ok_or_else(
MemoryContError::new,
)?)
}
}
2 changes: 1 addition & 1 deletion src/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ where
D: Dimension + RemoveAxis,
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
{
fn op_multi_inplace<'a>(&self, mut a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D> {
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D> {
let n = a.ndim();
for mut col in a.axis_iter_mut(Axis(n - 1)) {
self.op_inplace(&mut col);
Expand Down
2 changes: 1 addition & 1 deletion src/opnorm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use lapack_traits::NormType;

/// Operator norm using `*lange` LAPACK routines
///
/// https://en.wikipedia.org/wiki/Operator_norm
/// [Wikipedia article on operator norm](https://en.wikipedia.org/wiki/Operator_norm)
pub trait OperationNorm {
/// the value of norm
type Output: RealScalar;
Expand Down
6 changes: 3 additions & 3 deletions src/qr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! QR decomposition
//!
//! https://en.wikipedia.org/wiki/QR_decomposition
//! [Wikipedia article on QR decomposition](https://en.wikipedia.org/wiki/QR_decomposition)

use ndarray::*;
use num_traits::Zero;
Expand Down Expand Up @@ -49,7 +49,7 @@ pub trait QRSquareInto: Sized {
/// QR decomposition for mutable reference of square matrix
pub trait QRSquareInplace: Sized {
type R;
fn qr_square_inplace<'a>(&'a mut self) -> Result<(&'a mut Self, Self::R)>;
fn qr_square_inplace(&mut self) -> Result<(&mut Self, Self::R)>;
}

impl<A, S> QRSquareInplace for ArrayBase<S, Ix2>
Expand All @@ -59,7 +59,7 @@ where
{
type R = Array2<A>;

fn qr_square_inplace<'a>(&'a mut self) -> Result<(&'a mut Self, Self::R)> {
fn qr_square_inplace(&mut self) -> Result<(&mut Self, Self::R)> {
let l = self.square_layout()?;
let r = unsafe { A::qr(l, self.as_allocated_mut()?)? };
let r: Array2<_> = into_matrix(l, r)?;
Expand Down
12 changes: 6 additions & 6 deletions src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ where
A: Scalar,
S: Data<Elem = A>,
{
fn solve_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solve_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand All @@ -156,7 +156,7 @@ where
};
Ok(rhs)
}
fn solve_t_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solve_t_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand All @@ -171,7 +171,7 @@ where
};
Ok(rhs)
}
fn solve_h_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solve_h_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand All @@ -193,21 +193,21 @@ where
A: Scalar,
S: Data<Elem = A>,
{
fn solve_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solve_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
let f = self.factorize()?;
f.solve_inplace(rhs)
}
fn solve_t_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solve_t_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
let f = self.factorize()?;
f.solve_t_inplace(rhs)
}
fn solve_h_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solve_h_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand Down
2 changes: 1 addition & 1 deletion src/solveh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ where
A: Scalar,
S: Data<Elem = A>,
{
fn solveh_inplace<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand Down
4 changes: 2 additions & 2 deletions src/svd.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! singular-value decomposition
//! Singular-value decomposition (SVD)
//!
//! https://en.wikipedia.org/wiki/Singular_value_decomposition
//! [Wikipedia article on SVD](https://en.wikipedia.org/wiki/Singular_value_decomposition)

use ndarray::*;

Expand Down
2 changes: 1 addition & 1 deletion src/triangular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ where
&self,
uplo: UPLO,
diag: Diag,
mut b: &'a mut ArrayBase<So, Ix2>,
b: &'a mut ArrayBase<So, Ix2>,
) -> Result<&'a mut ArrayBase<So, Ix2>> {
let la = self.layout()?;
let a_ = self.as_allocated()?;
Expand Down
14 changes: 7 additions & 7 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ pub use num_complex::Complex64 as c64;
///
/// You can use the following operations with `A: Scalar`:
///
/// - [abs](trait.Absolute.html#method.abs)
/// - [abs_sqr](trait.Absolute.html#tymethod.abs_sqr)
/// - [sqrt](trait.SquareRoot.html#tymethod.sqrt)
/// - [exp](trait.Exponential.html#tymethod.exp)
/// - [ln](trait.NaturalLogarithm.html#tymethod.ln)
/// - [conj](trait.Conjugate.html#tymethod.conj)
/// - [randn](trait.RandNormal.html#tymethod.randn)
/// - [`abs`](trait.Absolute.html#method.abs)
/// - [`abs_sqr`](trait.Absolute.html#tymethod.abs_sqr)
/// - [`sqrt`](trait.SquareRoot.html#tymethod.sqrt)
/// - [`exp`](trait.Exponential.html#tymethod.exp)
/// - [`ln`](trait.NaturalLogarithm.html#tymethod.ln)
/// - [`conj`](trait.Conjugate.html#tymethod.conj)
/// - [`randn`](trait.RandNormal.html#tymethod.randn)
///
pub trait Scalar
: LapackScalar
Expand Down
10 changes: 5 additions & 5 deletions tests/det.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ndarray_linalg::*;
use num_traits::{One, Zero};

/// Returns the matrix with the specified `row` and `col` removed.
fn matrix_minor<A, S>(a: ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
fn matrix_minor<A, S>(a: &ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
where
A: Scalar,
S: Data<Elem = A>,
Expand All @@ -27,7 +27,7 @@ where
///
/// Note: This implementation is written to be clearly correct so that it's
/// useful for verification, but it's very inefficient.
fn det_naive<A, S>(a: ArrayBase<S, Ix2>) -> A
fn det_naive<A, S>(a: &ArrayBase<S, Ix2>) -> A
where
A: Scalar,
S: Data<Elem = A>,
Expand All @@ -40,7 +40,7 @@ where
(0..cols)
.map(|col| {
let sign = if col % 2 == 0 { A::one() } else { -A::one() };
sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col)))
sign * a[(0, col)] * det_naive(&matrix_minor(a, (0, col)))
})
.fold(A::zero(), |sum, subdet| sum + subdet)
}
Expand Down Expand Up @@ -102,7 +102,7 @@ fn det() {
($elem:ty, $shape:expr, $rtol:expr) => {
let a: Array2<$elem> = random($shape);
println!("a = \n{:?}", a);
let det = det_naive(a.view());
let det = det_naive(&a);
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
assert_rclose!(a.det().unwrap(), det, $rtol);
Expand Down Expand Up @@ -131,7 +131,7 @@ fn det_nonsquare() {
}
}
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
for &shape in &[dims.clone().into_shape(), dims.clone().f()] {
for &shape in &[dims.into_shape(), dims.f()] {
det_nonsquare!(f64, shape);
det_nonsquare!(f32, shape);
det_nonsquare!(c64, shape);
Expand Down
2 changes: 1 addition & 1 deletion tests/deth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn deth_nonsquare() {
}
}
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
for &shape in &[dims.clone().into_shape(), dims.clone().f()] {
for &shape in &[dims.into_shape(), dims.f()] {
deth_nonsquare!(f64, shape);
deth_nonsquare!(f32, shape);
deth_nonsquare!(c64, shape);
Expand Down
24 changes: 12 additions & 12 deletions tests/qr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use ndarray::*;
use ndarray_linalg::*;
use std::cmp::min;

fn test(a: Array2<f64>, n: usize, m: usize) {
fn test(a: &Array2<f64>, n: usize, m: usize) {
let ans = a.clone();
println!("a = \n{:?}", &a);
println!("a = \n{:?}", a);
let (q, r): (Array2<_>, Array2<_>) = a.qr().unwrap();
println!("q = \n{:?}", &q);
println!("r = \n{:?}", &r);
Expand All @@ -18,9 +18,9 @@ fn test(a: Array2<f64>, n: usize, m: usize) {
assert_close_l2!(&r.clone().into_triangular(UPLO::Upper), &r, 1e-7);
}

fn test_square(a: Array2<f64>, n: usize, m: usize) {
fn test_square(a: &Array2<f64>, n: usize, m: usize) {
let ans = a.clone();
println!("a = \n{:?}", &a);
println!("a = \n{:?}", a);
let (q, r): (Array2<_>, Array2<_>) = a.qr_square().unwrap();
println!("q = \n{:?}", &q);
println!("r = \n{:?}", &r);
Expand All @@ -32,47 +32,47 @@ fn test_square(a: Array2<f64>, n: usize, m: usize) {
#[test]
fn qr_sq() {
let a = random((3, 3));
test_square(a, 3, 3);
test_square(&a, 3, 3);
}

#[test]
fn qr_sq_t() {
let a = random((3, 3).f());
test_square(a, 3, 3);
test_square(&a, 3, 3);
}

#[test]
fn qr_3x3() {
let a = random((3, 3));
test(a, 3, 3);
test(&a, 3, 3);
}

#[test]
fn qr_3x3_t() {
let a = random((3, 3).f());
test(a, 3, 3);
test(&a, 3, 3);
}

#[test]
fn qr_3x4() {
let a = random((3, 4));
test(a, 3, 4);
test(&a, 3, 4);
}

#[test]
fn qr_3x4_t() {
let a = random((3, 4).f());
test(a, 3, 4);
test(&a, 3, 4);
}

#[test]
fn qr_4x3() {
let a = random((4, 3));
test(a, 4, 3);
test(&a, 4, 3);
}

#[test]
fn qr_4x3_t() {
let a = random((4, 3).f());
test(a, 4, 3);
test(&a, 4, 3);
}
Loading