Skip to content

Commit ba3c029

Browse files
move tests into blas-tests
1 parent 28cc3b0 commit ba3c029

File tree

2 files changed

+76
-102
lines changed

2 files changed

+76
-102
lines changed

crates/blas-tests/tests/dyn.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
extern crate blas_src;
2+
use ndarray::{Array1, Array2, ArrayD, linalg::Dot, Ix1, Ix2};
3+
4+
#[test]
5+
fn test_arrayd_dot_2d() {
6+
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
7+
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
8+
9+
let result = mat1.dot(&mat2);
10+
11+
// Verify the result is correct
12+
assert_eq!(result.ndim(), 2);
13+
assert_eq!(result.shape(), &[3, 3]);
14+
15+
// Compare with Array2 implementation
16+
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
17+
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
18+
let expected = mat1_2d.dot(&mat2_2d);
19+
20+
assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
21+
}
22+
23+
#[test]
24+
fn test_arrayd_dot_1d() {
25+
// Test 1D array dot product
26+
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
27+
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();
28+
29+
let result = vec1.dot(&vec2);
30+
31+
// Verify scalar result
32+
assert_eq!(result.ndim(), 0);
33+
assert_eq!(result.shape(), &[]);
34+
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
35+
}
36+
37+
#[test]
38+
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
39+
fn test_arrayd_dot_3d() {
40+
// Test that 3D arrays are not supported
41+
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
42+
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
43+
44+
let _result = arr1.dot(&arr2); // Should panic
45+
}
46+
47+
#[test]
48+
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
49+
fn test_arrayd_dot_incompatible_dims() {
50+
// Test arrays with incompatible dimensions
51+
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
52+
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();
53+
54+
let _result = arr1.dot(&arr2); // Should panic
55+
}
56+
57+
#[test]
58+
fn test_arrayd_dot_matrix_vector() {
59+
// Test matrix-vector multiplication
60+
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
61+
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
62+
63+
let result = mat.dot(&vec);
64+
65+
// Verify result
66+
assert_eq!(result.ndim(), 1);
67+
assert_eq!(result.shape(), &[3]);
68+
69+
// Compare with Array2 implementation
70+
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
71+
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
72+
let expected = mat_2d.dot(&vec_1d);
73+
74+
assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
75+
}

src/linalg/impl_linalg.rs

Lines changed: 1 addition & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,21 +1089,6 @@ mod blas_tests
10891089
/// - The array shapes are incompatible for the operation
10901090
/// - For vector dot product: the vectors have different lengths
10911091
///
1092-
/// # Examples
1093-
///
1094-
/// ```
1095-
/// use ndarray::{ArrayD, linalg::Dot};
1096-
///
1097-
/// // Matrix multiplication
1098-
/// let a = ArrayD::from_shape_vec(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
1099-
/// let b = ArrayD::from_shape_vec(vec![3, 2], vec![1., 2., 3., 4., 5., 6.]).unwrap();
1100-
/// let c = a.dot(&b);
1101-
///
1102-
/// // Vector dot product
1103-
/// let v1 = ArrayD::from_shape_vec(vec![3], vec![1., 2., 3.]).unwrap();
1104-
/// let v2 = ArrayD::from_shape_vec(vec![3], vec![4., 5., 6.]).unwrap();
1105-
/// let scalar = v1.dot(&v2);
1106-
/// ```
11071092
impl<A, S, S2> Dot<ArrayBase<S2, IxDyn>> for ArrayBase<S, IxDyn>
11081093
where
11091094
S: Data<Elem = A>,
@@ -1145,90 +1130,4 @@ where
11451130
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
11461131
}
11471132
}
1148-
}
1149-
1150-
#[cfg(test)]
1151-
mod arrayd_dot_tests
1152-
{
1153-
use super::*;
1154-
use crate::ArrayD;
1155-
1156-
#[test]
1157-
fn test_arrayd_dot_2d()
1158-
{
1159-
// Test case from the original issue
1160-
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
1161-
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
1162-
1163-
let result = mat1.dot(&mat2);
1164-
1165-
// Verify the result is correct
1166-
assert_eq!(result.ndim(), 2);
1167-
assert_eq!(result.shape(), &[3, 3]);
1168-
1169-
// Compare with Array2 implementation
1170-
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
1171-
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
1172-
let expected = mat1_2d.dot(&mat2_2d);
1173-
1174-
assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
1175-
}
1176-
1177-
#[test]
1178-
fn test_arrayd_dot_1d()
1179-
{
1180-
// Test 1D array dot product
1181-
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
1182-
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();
1183-
1184-
let result = vec1.dot(&vec2);
1185-
1186-
// Verify scalar result
1187-
assert_eq!(result.ndim(), 0);
1188-
assert_eq!(result.shape(), &[]);
1189-
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
1190-
}
1191-
1192-
#[test]
1193-
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
1194-
fn test_arrayd_dot_3d()
1195-
{
1196-
// Test that 3D arrays are not supported
1197-
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
1198-
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
1199-
1200-
let _result = arr1.dot(&arr2); // Should panic
1201-
}
1202-
1203-
#[test]
1204-
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
1205-
fn test_arrayd_dot_incompatible_dims()
1206-
{
1207-
// Test arrays with incompatible dimensions
1208-
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
1209-
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();
1210-
1211-
let _result = arr1.dot(&arr2); // Should panic
1212-
}
1213-
1214-
#[test]
1215-
fn test_arrayd_dot_matrix_vector()
1216-
{
1217-
// Test matrix-vector multiplication
1218-
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1219-
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
1220-
1221-
let result = mat.dot(&vec);
1222-
1223-
// Verify result
1224-
assert_eq!(result.ndim(), 1);
1225-
assert_eq!(result.shape(), &[3]);
1226-
1227-
// Compare with Array2 implementation
1228-
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1229-
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
1230-
let expected = mat_2d.dot(&vec_1d);
1231-
1232-
assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
1233-
}
1234-
}
1133+
}

0 commit comments

Comments
 (0)