Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions ndarray-linalg/src/assert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ pub fn aclose<A: Scalar>(test: A, truth: A, atol: A::Real) {
}

/// check two arrays are close in maximum norm
pub fn close_max<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: A::Real)
pub fn close_max<A, D>(test: &ArrayRef<A, D>, truth: &ArrayRef<A, D>, atol: A::Real)
where
A: Scalar + Lapack,
S1: Data<Elem = A>,
S2: Data<Elem = A>,
D: Dimension,
D::Pattern: PartialEq + Debug,
{
Expand All @@ -48,11 +46,9 @@ where
}

/// check two arrays are close in L1 norm
pub fn close_l1<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
pub fn close_l1<A, D>(test: &ArrayRef<A, D>, truth: &ArrayRef<A, D>, rtol: A::Real)
where
A: Scalar + Lapack,
S1: Data<Elem = A>,
S2: Data<Elem = A>,
D: Dimension,
D::Pattern: PartialEq + Debug,
{
Expand All @@ -67,11 +63,9 @@ where
}

/// check two arrays are close in L2 norm
pub fn close_l2<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
pub fn close_l2<A, D>(test: &ArrayRef<A, D>, truth: &ArrayRef<A, D>, rtol: A::Real)
where
A: Scalar + Lapack,
S1: Data<Elem = A>,
S2: Data<Elem = A>,
D: Dimension,
D::Pattern: PartialEq + Debug,
{
Expand Down
41 changes: 10 additions & 31 deletions ndarray-linalg/src/cholesky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,7 @@ where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solvec_inplace<'a, Sb>(
&self,
b: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
fn solvec_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
A::solve_cholesky(
self.factor.square_layout()?,
self.uplo,
Expand Down Expand Up @@ -225,10 +219,9 @@ pub trait CholeskyInplace {
fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self>;
}

impl<A, S> Cholesky for ArrayBase<S, Ix2>
impl<A> Cholesky for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
type Output = Array2<A>;

Expand All @@ -251,10 +244,9 @@ where
}
}

impl<A, S> CholeskyInplace for ArrayBase<S, Ix2>
impl<A> CholeskyInplace for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> {
A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?;
Expand Down Expand Up @@ -301,10 +293,9 @@ where
}
}

impl<A, Si> FactorizeC<OwnedRepr<A>> for ArrayBase<Si, Ix2>
impl<A> FactorizeC<OwnedRepr<A>> for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
Si: Data<Elem = A>,
{
fn factorizec(&self, uplo: UPLO) -> Result<CholeskyFactorized<OwnedRepr<A>>> {
Ok(CholeskyFactorized {
Expand All @@ -320,7 +311,7 @@ pub trait SolveC<A: Scalar> {
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
/// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
/// the argument, and `x` is the successful result.
fn solvec<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
fn solvec(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
let mut b = replicate(b);
self.solvec_inplace(&mut b)?;
Ok(b)
Expand All @@ -339,24 +330,14 @@ pub trait SolveC<A: Scalar> {
/// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
/// the argument, and `x` is the successful result. The value of `x` is
/// also assigned to the argument.
fn solvec_inplace<'a, S: DataMut<Elem = A>>(
&self,
b: &'a mut ArrayBase<S, Ix1>,
) -> Result<&'a mut ArrayBase<S, Ix1>>;
fn solvec_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
}

impl<A, S> SolveC<A> for ArrayBase<S, Ix2>
impl<A> SolveC<A> for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solvec_inplace<'a, Sb>(
&self,
b: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
fn solvec_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
self.factorizec(UPLO::Upper)?.solvec_inplace(b)
}
}
Expand All @@ -377,10 +358,9 @@ pub trait InverseCInto {
fn invc_into(self) -> Result<Self::Output>;
}

impl<A, S> InverseC for ArrayBase<S, Ix2>
impl<A> InverseC for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
type Output = Array2<A>;

Expand Down Expand Up @@ -435,10 +415,9 @@ pub trait DeterminantCInto {
fn ln_detc_into(self) -> Self::Output;
}

impl<A, S> DeterminantC for ArrayBase<S, Ix2>
impl<A> DeterminantC for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
type Output = Result<<A as Scalar>::Real>;

Expand Down
17 changes: 7 additions & 10 deletions ndarray-linalg/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,31 @@ where
}
}

pub fn replicate<A, Sv, So, D>(a: &ArrayBase<Sv, D>) -> ArrayBase<So, D>
pub fn replicate<A, S, D>(a: &ArrayRef<A, D>) -> ArrayBase<S, D>
where
A: Copy,
Sv: Data<Elem = A>,
So: DataOwned<Elem = A> + DataMut,
S: DataOwned<Elem = A> + DataMut,
D: Dimension,
{
unsafe {
let ret = ArrayBase::<So, D>::build_uninit(a.dim(), |view| {
let ret = ArrayBase::<S, D>::build_uninit(a.dim(), |view| {
a.assign_to(view);
});
ret.assume_init()
}
}

fn clone_with_layout<A, Si, So>(l: MatrixLayout, a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
fn clone_with_layout<A, S>(l: MatrixLayout, a: &ArrayRef<A, Ix2>) -> ArrayBase<S, Ix2>
where
A: Copy,
Si: Data<Elem = A>,
So: DataOwned<Elem = A> + DataMut,
S: DataOwned<Elem = A> + DataMut,
{
let shape_builder = match l {
MatrixLayout::C { row, lda } => (row as usize, lda as usize).set_f(false),
MatrixLayout::F { col, lda } => (lda as usize, col as usize).set_f(true),
};
unsafe {
let ret = ArrayBase::<So, _>::build_uninit(shape_builder, |view| {
let ret = ArrayBase::<S, _>::build_uninit(shape_builder, |view| {
a.assign_to(view);
});
ret.assume_init()
Expand Down Expand Up @@ -119,10 +117,9 @@ where
/// data in the triangular portion corresponding to `uplo`.
///
/// ***Panics*** if `a` is not square.
pub(crate) fn triangular_fill_hermitian<A, S>(a: &mut ArrayBase<S, Ix2>, uplo: UPLO)
pub(crate) fn triangular_fill_hermitian<A>(a: &mut ArrayRef<A, Ix2>, uplo: UPLO)
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
assert!(a.is_square());
match uplo {
Expand Down
7 changes: 2 additions & 5 deletions ndarray-linalg/src/diagonal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl<S: Data> IntoDiagonal<S> for ArrayBase<S, Ix1> {
}
}

impl<A, S: Data<Elem = A>> AsDiagonal<A> for ArrayBase<S, Ix1> {
impl<A> AsDiagonal<A> for ArrayRef<A, Ix1> {
fn as_diagonal(&self) -> Diagonal<ViewRepr<&A>> {
Diagonal { diag: self.view() }
}
Expand All @@ -37,10 +37,7 @@ where
{
type Elem = A;

fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
where
S: DataMut<Elem = A>,
{
fn apply_mut(&self, a: &mut ArrayRef<A, Ix1>) {
for (val, d) in a.iter_mut().zip(self.diag.iter()) {
*val *= *d;
}
Expand Down
92 changes: 78 additions & 14 deletions ndarray-linalg/src/eig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ pub trait Eig {
fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)>;
}

impl<A, S> Eig for ArrayBase<S, Ix2>
impl<A> Eig for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
type EigVal = Array1<A::Complex>;
type EigVec = Array2<A::Complex>;
Expand All @@ -65,10 +64,9 @@ pub trait EigVals {
fn eigvals(&self) -> Result<Self::EigVal>;
}

impl<A, S> EigVals for ArrayBase<S, Ix2>
impl<A> EigVals for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
type EigVal = Array1<A::Complex>;

Expand Down Expand Up @@ -127,27 +125,71 @@ pub trait EigGeneralized {
/// computing the eigenvalues as α/β. If `None`, no approximate comparisons to zero will be
/// made.
fn eig_generalized(
&self,
self,
thresh_opt: Option<Self::Real>,
) -> Result<(Self::EigVal, Self::EigVec)>;
}

impl<A, S> EigGeneralized for (ArrayBase<S, Ix2>, ArrayBase<S, Ix2>)
/// Turn arrays, references to arrays, and [`ArrayRef`]s into owned arrays
pub trait MaybeOwnedMatrix {
type Elem;

/// Convert into an owned array, cloning only when necessary.
fn into_owned(self) -> Array2<Self::Elem>;
}

impl<S> MaybeOwnedMatrix for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
S: Data,
S::Elem: Clone,
{
type EigVal = Array1<GeneralizedEigenvalue<A::Complex>>;
type EigVec = Array2<A::Complex>;
type Real = A::Real;
type Elem = S::Elem;

fn into_owned(self) -> Array2<S::Elem> {
ArrayBase::into_owned(self)
}
}

impl<S> MaybeOwnedMatrix for &ArrayBase<S, Ix2>
where
S: Data,
S::Elem: Clone,
{
type Elem = S::Elem;

fn into_owned(self) -> Array2<S::Elem> {
self.to_owned()
}
}

impl<A> MaybeOwnedMatrix for &ArrayRef2<A>
where
A: Clone,
{
type Elem = A;

fn into_owned(self) -> Array2<A> {
self.to_owned()
}
}

impl<T1, T2> EigGeneralized for (T1, T2)
where
T1: MaybeOwnedMatrix,
T1::Elem: Lapack + Scalar,
T2: MaybeOwnedMatrix<Elem = T1::Elem>,
{
type EigVal = Array1<GeneralizedEigenvalue<<T1::Elem as Scalar>::Complex>>;
type EigVec = Array2<<T1::Elem as Scalar>::Complex>;
type Real = <T1::Elem as Scalar>::Real;

fn eig_generalized(
&self,
self,
thresh_opt: Option<Self::Real>,
) -> Result<(Self::EigVal, Self::EigVec)> {
let (mut a, mut b) = (self.0.to_owned(), self.1.to_owned());
let (mut a, mut b) = (self.0.into_owned(), self.1.into_owned());
let layout = a.square_layout()?;
let (s, t) = A::eig_generalized(
let (s, t) = T1::Elem::eig_generalized(
true,
layout,
a.as_allocated_mut()?,
Expand All @@ -161,3 +203,25 @@ where
))
}
}

#[cfg(test)]
mod tests {
use crate::MaybeOwnedMatrix;

#[test]
fn test_maybe_owned_matrix() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let a_ptr = a.as_ptr();
let a1 = MaybeOwnedMatrix::into_owned(a);
assert_eq!(a_ptr, a1.as_ptr());

let b = a1.clone();
let b1 = MaybeOwnedMatrix::into_owned(&b);
assert_eq!(b, b1);
assert_ne!(b.as_ptr(), b1.as_ptr());

let b2 = MaybeOwnedMatrix::into_owned(&*b);
assert_eq!(b, b2);
assert_ne!(b.as_ptr(), b2.as_ptr());
}
}
3 changes: 1 addition & 2 deletions ndarray-linalg/src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ where
}
}

impl<A, S> EighInplace for ArrayBase<S, Ix2>
impl<A> EighInplace for ArrayRef<A, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
type EigVal = Array1<A::Real>;

Expand Down
7 changes: 3 additions & 4 deletions ndarray-linalg/src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use super::qr::*;
use super::types::*;

/// Hermite conjugate matrix
pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
pub fn conjugate<A, S>(a: &ArrayRef<A, Ix2>) -> ArrayBase<S, Ix2>
where
A: Scalar,
Si: Data<Elem = A>,
So: DataOwned<Elem = A> + DataMut,
S: DataOwned<Elem = A> + DataMut,
{
let mut a: ArrayBase<So, Ix2> = replicate(&a.t());
let mut a: ArrayBase<S, Ix2> = replicate(&a.t());
for val in a.iter_mut() {
*val = val.conj();
}
Expand Down
Loading