Skip to content

Commit 4520265

Browse files
authored
Merge pull request #415 from akern40/reference-type-full
2 parents 49e964b + 729edc0 commit 4520265

23 files changed

+262
-444
lines changed

ndarray-linalg/src/assert.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ pub fn aclose<A: Scalar>(test: A, truth: A, atol: A::Real) {
2929
}
3030

3131
/// check two arrays are close in maximum norm
32-
pub fn close_max<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: A::Real)
32+
pub fn close_max<A, D>(test: &ArrayRef<A, D>, truth: &ArrayRef<A, D>, atol: A::Real)
3333
where
3434
A: Scalar + Lapack,
35-
S1: Data<Elem = A>,
36-
S2: Data<Elem = A>,
3735
D: Dimension,
3836
D::Pattern: PartialEq + Debug,
3937
{
@@ -48,11 +46,9 @@ where
4846
}
4947

5048
/// check two arrays are close in L1 norm
51-
pub fn close_l1<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
49+
pub fn close_l1<A, D>(test: &ArrayRef<A, D>, truth: &ArrayRef<A, D>, rtol: A::Real)
5250
where
5351
A: Scalar + Lapack,
54-
S1: Data<Elem = A>,
55-
S2: Data<Elem = A>,
5652
D: Dimension,
5753
D::Pattern: PartialEq + Debug,
5854
{
@@ -67,11 +63,9 @@ where
6763
}
6864

6965
/// check two arrays are close in L2 norm
70-
pub fn close_l2<A, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: A::Real)
66+
pub fn close_l2<A, D>(test: &ArrayRef<A, D>, truth: &ArrayRef<A, D>, rtol: A::Real)
7167
where
7268
A: Scalar + Lapack,
73-
S1: Data<Elem = A>,
74-
S2: Data<Elem = A>,
7569
D: Dimension,
7670
D::Pattern: PartialEq + Debug,
7771
{

ndarray-linalg/src/cholesky.rs

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,7 @@ where
166166
A: Scalar + Lapack,
167167
S: Data<Elem = A>,
168168
{
169-
fn solvec_inplace<'a, Sb>(
170-
&self,
171-
b: &'a mut ArrayBase<Sb, Ix1>,
172-
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
173-
where
174-
Sb: DataMut<Elem = A>,
175-
{
169+
fn solvec_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
176170
A::solve_cholesky(
177171
self.factor.square_layout()?,
178172
self.uplo,
@@ -225,10 +219,9 @@ pub trait CholeskyInplace {
225219
fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self>;
226220
}
227221

228-
impl<A, S> Cholesky for ArrayBase<S, Ix2>
222+
impl<A> Cholesky for ArrayRef<A, Ix2>
229223
where
230224
A: Scalar + Lapack,
231-
S: Data<Elem = A>,
232225
{
233226
type Output = Array2<A>;
234227

@@ -251,10 +244,9 @@ where
251244
}
252245
}
253246

254-
impl<A, S> CholeskyInplace for ArrayBase<S, Ix2>
247+
impl<A> CholeskyInplace for ArrayRef<A, Ix2>
255248
where
256249
A: Scalar + Lapack,
257-
S: DataMut<Elem = A>,
258250
{
259251
fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> {
260252
A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?;
@@ -301,10 +293,9 @@ where
301293
}
302294
}
303295

304-
impl<A, Si> FactorizeC<OwnedRepr<A>> for ArrayBase<Si, Ix2>
296+
impl<A> FactorizeC<OwnedRepr<A>> for ArrayRef<A, Ix2>
305297
where
306298
A: Scalar + Lapack,
307-
Si: Data<Elem = A>,
308299
{
309300
fn factorizec(&self, uplo: UPLO) -> Result<CholeskyFactorized<OwnedRepr<A>>> {
310301
Ok(CholeskyFactorized {
@@ -320,7 +311,7 @@ pub trait SolveC<A: Scalar> {
320311
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
321312
/// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
322313
/// the argument, and `x` is the successful result.
323-
fn solvec<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
314+
fn solvec(&self, b: &ArrayRef<A, Ix1>) -> Result<Array1<A>> {
324315
let mut b = replicate(b);
325316
self.solvec_inplace(&mut b)?;
326317
Ok(b)
@@ -339,24 +330,14 @@ pub trait SolveC<A: Scalar> {
339330
/// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
340331
/// the argument, and `x` is the successful result. The value of `x` is
341332
/// also assigned to the argument.
342-
fn solvec_inplace<'a, S: DataMut<Elem = A>>(
343-
&self,
344-
b: &'a mut ArrayBase<S, Ix1>,
345-
) -> Result<&'a mut ArrayBase<S, Ix1>>;
333+
fn solvec_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>>;
346334
}
347335

348-
impl<A, S> SolveC<A> for ArrayBase<S, Ix2>
336+
impl<A> SolveC<A> for ArrayRef<A, Ix2>
349337
where
350338
A: Scalar + Lapack,
351-
S: Data<Elem = A>,
352339
{
353-
fn solvec_inplace<'a, Sb>(
354-
&self,
355-
b: &'a mut ArrayBase<Sb, Ix1>,
356-
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
357-
where
358-
Sb: DataMut<Elem = A>,
359-
{
340+
fn solvec_inplace<'a>(&self, b: &'a mut ArrayRef<A, Ix1>) -> Result<&'a mut ArrayRef<A, Ix1>> {
360341
self.factorizec(UPLO::Upper)?.solvec_inplace(b)
361342
}
362343
}
@@ -377,10 +358,9 @@ pub trait InverseCInto {
377358
fn invc_into(self) -> Result<Self::Output>;
378359
}
379360

380-
impl<A, S> InverseC for ArrayBase<S, Ix2>
361+
impl<A> InverseC for ArrayRef<A, Ix2>
381362
where
382363
A: Scalar + Lapack,
383-
S: Data<Elem = A>,
384364
{
385365
type Output = Array2<A>;
386366

@@ -435,10 +415,9 @@ pub trait DeterminantCInto {
435415
fn ln_detc_into(self) -> Self::Output;
436416
}
437417

438-
impl<A, S> DeterminantC for ArrayBase<S, Ix2>
418+
impl<A> DeterminantC for ArrayRef<A, Ix2>
439419
where
440420
A: Scalar + Lapack,
441-
S: Data<Elem = A>,
442421
{
443422
type Output = Result<<A as Scalar>::Real>;
444423

ndarray-linalg/src/convert.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,31 @@ where
4646
}
4747
}
4848

49-
pub fn replicate<A, Sv, So, D>(a: &ArrayBase<Sv, D>) -> ArrayBase<So, D>
49+
pub fn replicate<A, S, D>(a: &ArrayRef<A, D>) -> ArrayBase<S, D>
5050
where
5151
A: Copy,
52-
Sv: Data<Elem = A>,
53-
So: DataOwned<Elem = A> + DataMut,
52+
S: DataOwned<Elem = A> + DataMut,
5453
D: Dimension,
5554
{
5655
unsafe {
57-
let ret = ArrayBase::<So, D>::build_uninit(a.dim(), |view| {
56+
let ret = ArrayBase::<S, D>::build_uninit(a.dim(), |view| {
5857
a.assign_to(view);
5958
});
6059
ret.assume_init()
6160
}
6261
}
6362

64-
fn clone_with_layout<A, Si, So>(l: MatrixLayout, a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
63+
fn clone_with_layout<A, S>(l: MatrixLayout, a: &ArrayRef<A, Ix2>) -> ArrayBase<S, Ix2>
6564
where
6665
A: Copy,
67-
Si: Data<Elem = A>,
68-
So: DataOwned<Elem = A> + DataMut,
66+
S: DataOwned<Elem = A> + DataMut,
6967
{
7068
let shape_builder = match l {
7169
MatrixLayout::C { row, lda } => (row as usize, lda as usize).set_f(false),
7270
MatrixLayout::F { col, lda } => (lda as usize, col as usize).set_f(true),
7371
};
7472
unsafe {
75-
let ret = ArrayBase::<So, _>::build_uninit(shape_builder, |view| {
73+
let ret = ArrayBase::<S, _>::build_uninit(shape_builder, |view| {
7674
a.assign_to(view);
7775
});
7876
ret.assume_init()
@@ -119,10 +117,9 @@ where
119117
/// data in the triangular portion corresponding to `uplo`.
120118
///
121119
/// ***Panics*** if `a` is not square.
122-
pub(crate) fn triangular_fill_hermitian<A, S>(a: &mut ArrayBase<S, Ix2>, uplo: UPLO)
120+
pub(crate) fn triangular_fill_hermitian<A>(a: &mut ArrayRef<A, Ix2>, uplo: UPLO)
123121
where
124122
A: Scalar + Lapack,
125-
S: DataMut<Elem = A>,
126123
{
127124
assert!(a.is_square());
128125
match uplo {

ndarray-linalg/src/diagonal.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl<S: Data> IntoDiagonal<S> for ArrayBase<S, Ix1> {
2424
}
2525
}
2626

27-
impl<A, S: Data<Elem = A>> AsDiagonal<A> for ArrayBase<S, Ix1> {
27+
impl<A> AsDiagonal<A> for ArrayRef<A, Ix1> {
2828
fn as_diagonal(&self) -> Diagonal<ViewRepr<&A>> {
2929
Diagonal { diag: self.view() }
3030
}
@@ -37,10 +37,7 @@ where
3737
{
3838
type Elem = A;
3939

40-
fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
41-
where
42-
S: DataMut<Elem = A>,
43-
{
40+
fn apply_mut(&self, a: &mut ArrayRef<A, Ix1>) {
4441
for (val, d) in a.iter_mut().zip(self.diag.iter()) {
4542
*val *= *d;
4643
}

ndarray-linalg/src/eig.rs

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@ pub trait Eig {
3939
fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)>;
4040
}
4141

42-
impl<A, S> Eig for ArrayBase<S, Ix2>
42+
impl<A> Eig for ArrayRef<A, Ix2>
4343
where
4444
A: Scalar + Lapack,
45-
S: Data<Elem = A>,
4645
{
4746
type EigVal = Array1<A::Complex>;
4847
type EigVec = Array2<A::Complex>;
@@ -65,10 +64,9 @@ pub trait EigVals {
6564
fn eigvals(&self) -> Result<Self::EigVal>;
6665
}
6766

68-
impl<A, S> EigVals for ArrayBase<S, Ix2>
67+
impl<A> EigVals for ArrayRef<A, Ix2>
6968
where
7069
A: Scalar + Lapack,
71-
S: Data<Elem = A>,
7270
{
7371
type EigVal = Array1<A::Complex>;
7472

@@ -127,27 +125,71 @@ pub trait EigGeneralized {
127125
/// computing the eigenvalues as α/β. If `None`, no approximate comparisons to zero will be
128126
/// made.
129127
fn eig_generalized(
130-
&self,
128+
self,
131129
thresh_opt: Option<Self::Real>,
132130
) -> Result<(Self::EigVal, Self::EigVec)>;
133131
}
134132

135-
impl<A, S> EigGeneralized for (ArrayBase<S, Ix2>, ArrayBase<S, Ix2>)
133+
/// Turn arrays, references to arrays, and [`ArrayRef`]s into owned arrays
134+
pub trait MaybeOwnedMatrix {
135+
type Elem;
136+
137+
/// Convert into an owned array, cloning only when necessary.
138+
fn into_owned(self) -> Array2<Self::Elem>;
139+
}
140+
141+
impl<S> MaybeOwnedMatrix for ArrayBase<S, Ix2>
136142
where
137-
A: Scalar + Lapack,
138-
S: Data<Elem = A>,
143+
S: Data,
144+
S::Elem: Clone,
139145
{
140-
type EigVal = Array1<GeneralizedEigenvalue<A::Complex>>;
141-
type EigVec = Array2<A::Complex>;
142-
type Real = A::Real;
146+
type Elem = S::Elem;
147+
148+
fn into_owned(self) -> Array2<S::Elem> {
149+
ArrayBase::into_owned(self)
150+
}
151+
}
152+
153+
impl<S> MaybeOwnedMatrix for &ArrayBase<S, Ix2>
154+
where
155+
S: Data,
156+
S::Elem: Clone,
157+
{
158+
type Elem = S::Elem;
159+
160+
fn into_owned(self) -> Array2<S::Elem> {
161+
self.to_owned()
162+
}
163+
}
164+
165+
impl<A> MaybeOwnedMatrix for &ArrayRef2<A>
166+
where
167+
A: Clone,
168+
{
169+
type Elem = A;
170+
171+
fn into_owned(self) -> Array2<A> {
172+
self.to_owned()
173+
}
174+
}
175+
176+
impl<T1, T2> EigGeneralized for (T1, T2)
177+
where
178+
T1: MaybeOwnedMatrix,
179+
T1::Elem: Lapack + Scalar,
180+
T2: MaybeOwnedMatrix<Elem = T1::Elem>,
181+
{
182+
type EigVal = Array1<GeneralizedEigenvalue<<T1::Elem as Scalar>::Complex>>;
183+
type EigVec = Array2<<T1::Elem as Scalar>::Complex>;
184+
type Real = <T1::Elem as Scalar>::Real;
143185

144186
fn eig_generalized(
145-
&self,
187+
self,
146188
thresh_opt: Option<Self::Real>,
147189
) -> Result<(Self::EigVal, Self::EigVec)> {
148-
let (mut a, mut b) = (self.0.to_owned(), self.1.to_owned());
190+
let (mut a, mut b) = (self.0.into_owned(), self.1.into_owned());
149191
let layout = a.square_layout()?;
150-
let (s, t) = A::eig_generalized(
192+
let (s, t) = T1::Elem::eig_generalized(
151193
true,
152194
layout,
153195
a.as_allocated_mut()?,
@@ -161,3 +203,25 @@ where
161203
))
162204
}
163205
}
206+
207+
#[cfg(test)]
208+
mod tests {
209+
use crate::MaybeOwnedMatrix;
210+
211+
#[test]
212+
fn test_maybe_owned_matrix() {
213+
let a = array![[1.0, 2.0], [3.0, 4.0]];
214+
let a_ptr = a.as_ptr();
215+
let a1 = MaybeOwnedMatrix::into_owned(a);
216+
assert_eq!(a_ptr, a1.as_ptr());
217+
218+
let b = a1.clone();
219+
let b1 = MaybeOwnedMatrix::into_owned(&b);
220+
assert_eq!(b, b1);
221+
assert_ne!(b.as_ptr(), b1.as_ptr());
222+
223+
let b2 = MaybeOwnedMatrix::into_owned(&*b);
224+
assert_eq!(b, b2);
225+
assert_ne!(b.as_ptr(), b2.as_ptr());
226+
}
227+
}

ndarray-linalg/src/eigh.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,9 @@ where
117117
}
118118
}
119119

120-
impl<A, S> EighInplace for ArrayBase<S, Ix2>
120+
impl<A> EighInplace for ArrayRef<A, Ix2>
121121
where
122122
A: Scalar + Lapack,
123-
S: DataMut<Elem = A>,
124123
{
125124
type EigVal = Array1<A::Real>;
126125

ndarray-linalg/src/generate.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@ use super::qr::*;
99
use super::types::*;
1010

1111
/// Hermite conjugate matrix
12-
pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
12+
pub fn conjugate<A, S>(a: &ArrayRef<A, Ix2>) -> ArrayBase<S, Ix2>
1313
where
1414
A: Scalar,
15-
Si: Data<Elem = A>,
16-
So: DataOwned<Elem = A> + DataMut,
15+
S: DataOwned<Elem = A> + DataMut,
1716
{
18-
let mut a: ArrayBase<So, Ix2> = replicate(&a.t());
17+
let mut a: ArrayBase<S, Ix2> = replicate(&a.t());
1918
for val in a.iter_mut() {
2019
*val = val.conj();
2120
}

0 commit comments

Comments
 (0)