Skip to content

Commit e34747c

Browse files
impl and tests added
1 parent 41bace1 commit e34747c

File tree

1 file changed

+127
-1
lines changed

1 file changed

+127
-1
lines changed

src/linalg/impl_linalg.rs

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ where
353353
///
354354
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
355355
///
356-
/// **Panics** if broadcasting isnt possible.
356+
/// **Panics** if broadcasting isn't possible.
357357
#[track_caller]
358358
pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
359359
where
@@ -1067,3 +1067,129 @@ mod blas_tests
10671067
}
10681068
}
10691069
}
1070+
1071+
impl<A, S, S2> Dot<ArrayBase<S2, IxDyn>> for ArrayBase<S, IxDyn>
1072+
where
1073+
S: Data<Elem = A>,
1074+
S2: Data<Elem = A>,
1075+
A: LinalgScalar,
1076+
{
1077+
type Output = Array<A, IxDyn>;
1078+
1079+
fn dot(&self, rhs: &ArrayBase<S2, IxDyn>) -> Self::Output {
1080+
match (self.ndim(), rhs.ndim()) {
1081+
(1, 1) => {
1082+
// Vector-vector dot product
1083+
if self.len() != rhs.len() {
1084+
panic!("Vector lengths must match for dot product");
1085+
}
1086+
let a = self.view().into_dimensionality::<Ix1>().unwrap();
1087+
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
1088+
let result = a.dot(&b);
1089+
ArrayD::from_elem(vec![], result)
1090+
}
1091+
(2, 2) => {
1092+
// Matrix-matrix multiplication
1093+
let a = self.view().into_dimensionality::<Ix2>().unwrap();
1094+
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
1095+
let result = a.dot(&b);
1096+
result.into_dimensionality::<IxDyn>().unwrap()
1097+
}
1098+
(2, 1) => {
1099+
// Matrix-vector multiplication
1100+
let a = self.view().into_dimensionality::<Ix2>().unwrap();
1101+
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
1102+
let result = a.dot(&b);
1103+
result.into_dimensionality::<IxDyn>().unwrap()
1104+
}
1105+
(1, 2) => {
1106+
// Vector-matrix multiplication
1107+
let a = self.view().into_dimensionality::<Ix1>().unwrap();
1108+
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
1109+
let result = a.dot(&b);
1110+
result.into_dimensionality::<IxDyn>().unwrap()
1111+
}
1112+
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
1113+
}
1114+
}
1115+
}
1116+
1117+
#[cfg(test)]
1118+
mod arrayd_dot_tests {
1119+
use super::*;
1120+
use crate::ArrayD;
1121+
1122+
#[test]
1123+
fn test_arrayd_dot_2d() {
1124+
// Test case from the original issue
1125+
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
1126+
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
1127+
1128+
let result = mat1.dot(&mat2);
1129+
1130+
// Verify the result is correct
1131+
assert_eq!(result.ndim(), 2);
1132+
assert_eq!(result.shape(), &[3, 3]);
1133+
1134+
// Compare with Array2 implementation
1135+
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
1136+
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
1137+
let expected = mat1_2d.dot(&mat2_2d);
1138+
1139+
assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
1140+
}
1141+
1142+
#[test]
1143+
fn test_arrayd_dot_1d() {
1144+
// Test 1D array dot product
1145+
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
1146+
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();
1147+
1148+
let result = vec1.dot(&vec2);
1149+
1150+
// Verify scalar result
1151+
assert_eq!(result.ndim(), 0);
1152+
assert_eq!(result.shape(), &[]);
1153+
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
1154+
}
1155+
1156+
#[test]
1157+
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
1158+
fn test_arrayd_dot_3d() {
1159+
// Test that 3D arrays are not supported
1160+
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
1161+
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
1162+
1163+
let _result = arr1.dot(&arr2); // Should panic
1164+
}
1165+
1166+
#[test]
1167+
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
1168+
fn test_arrayd_dot_incompatible_dims() {
1169+
// Test arrays with incompatible dimensions
1170+
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
1171+
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();
1172+
1173+
let _result = arr1.dot(&arr2); // Should panic
1174+
}
1175+
1176+
#[test]
1177+
fn test_arrayd_dot_matrix_vector() {
1178+
// Test matrix-vector multiplication
1179+
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1180+
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
1181+
1182+
let result = mat.dot(&vec);
1183+
1184+
// Verify result
1185+
assert_eq!(result.ndim(), 1);
1186+
assert_eq!(result.shape(), &[3]);
1187+
1188+
// Compare with Array2 implementation
1189+
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1190+
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
1191+
let expected = mat_2d.dot(&vec_1d);
1192+
1193+
assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
1194+
}
1195+
}

0 commit comments

Comments
 (0)