@@ -353,7 +353,7 @@ where
353353 ///
354354 /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
355355 ///
356- /// **Panics** if broadcasting isn’ t possible.
356+ /// **Panics** if broadcasting isn' t possible.
357357 #[ track_caller]
358358 pub fn scaled_add < S2 , E > ( & mut self , alpha : A , rhs : & ArrayBase < S2 , E > )
359359 where
@@ -1067,3 +1067,129 @@ mod blas_tests
10671067 }
10681068 }
10691069}
1070+
1071+ impl < A , S , S2 > Dot < ArrayBase < S2 , IxDyn > > for ArrayBase < S , IxDyn >
1072+ where
1073+ S : Data < Elem = A > ,
1074+ S2 : Data < Elem = A > ,
1075+ A : LinalgScalar ,
1076+ {
1077+ type Output = Array < A , IxDyn > ;
1078+
1079+ fn dot ( & self , rhs : & ArrayBase < S2 , IxDyn > ) -> Self :: Output {
1080+ match ( self . ndim ( ) , rhs. ndim ( ) ) {
1081+ ( 1 , 1 ) => {
1082+ // Vector-vector dot product
1083+ if self . len ( ) != rhs. len ( ) {
1084+ panic ! ( "Vector lengths must match for dot product" ) ;
1085+ }
1086+ let a = self . view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
1087+ let b = rhs. view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
1088+ let result = a. dot ( & b) ;
1089+ ArrayD :: from_elem ( vec ! [ ] , result)
1090+ }
1091+ ( 2 , 2 ) => {
1092+ // Matrix-matrix multiplication
1093+ let a = self . view ( ) . into_dimensionality :: < Ix2 > ( ) . unwrap ( ) ;
1094+ let b = rhs. view ( ) . into_dimensionality :: < Ix2 > ( ) . unwrap ( ) ;
1095+ let result = a. dot ( & b) ;
1096+ result. into_dimensionality :: < IxDyn > ( ) . unwrap ( )
1097+ }
1098+ ( 2 , 1 ) => {
1099+ // Matrix-vector multiplication
1100+ let a = self . view ( ) . into_dimensionality :: < Ix2 > ( ) . unwrap ( ) ;
1101+ let b = rhs. view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
1102+ let result = a. dot ( & b) ;
1103+ result. into_dimensionality :: < IxDyn > ( ) . unwrap ( )
1104+ }
1105+ ( 1 , 2 ) => {
1106+ // Vector-matrix multiplication
1107+ let a = self . view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
1108+ let b = rhs. view ( ) . into_dimensionality :: < Ix2 > ( ) . unwrap ( ) ;
1109+ let result = a. dot ( & b) ;
1110+ result. into_dimensionality :: < IxDyn > ( ) . unwrap ( )
1111+ }
1112+ _ => panic ! ( "Dot product for ArrayD is only supported for 1D and 2D arrays" ) ,
1113+ }
1114+ }
1115+ }
1116+
1117+ #[ cfg( test) ]
1118+ mod arrayd_dot_tests {
1119+ use super :: * ;
1120+ use crate :: ArrayD ;
1121+
1122+ #[ test]
1123+ fn test_arrayd_dot_2d ( ) {
1124+ // Test case from the original issue
1125+ let mat1 = ArrayD :: from_shape_vec ( vec ! [ 3 , 2 ] , vec ! [ 3.0 ; 6 ] ) . unwrap ( ) ;
1126+ let mat2 = ArrayD :: from_shape_vec ( vec ! [ 2 , 3 ] , vec ! [ 1.0 ; 6 ] ) . unwrap ( ) ;
1127+
1128+ let result = mat1. dot ( & mat2) ;
1129+
1130+ // Verify the result is correct
1131+ assert_eq ! ( result. ndim( ) , 2 ) ;
1132+ assert_eq ! ( result. shape( ) , & [ 3 , 3 ] ) ;
1133+
1134+ // Compare with Array2 implementation
1135+ let mat1_2d = Array2 :: from_shape_vec ( ( 3 , 2 ) , vec ! [ 3.0 ; 6 ] ) . unwrap ( ) ;
1136+ let mat2_2d = Array2 :: from_shape_vec ( ( 2 , 3 ) , vec ! [ 1.0 ; 6 ] ) . unwrap ( ) ;
1137+ let expected = mat1_2d. dot ( & mat2_2d) ;
1138+
1139+ assert_eq ! ( result. into_dimensionality:: <Ix2 >( ) . unwrap( ) , expected) ;
1140+ }
1141+
1142+ #[ test]
1143+ fn test_arrayd_dot_1d ( ) {
1144+ // Test 1D array dot product
1145+ let vec1 = ArrayD :: from_shape_vec ( vec ! [ 3 ] , vec ! [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
1146+ let vec2 = ArrayD :: from_shape_vec ( vec ! [ 3 ] , vec ! [ 4.0 , 5.0 , 6.0 ] ) . unwrap ( ) ;
1147+
1148+ let result = vec1. dot ( & vec2) ;
1149+
1150+ // Verify scalar result
1151+ assert_eq ! ( result. ndim( ) , 0 ) ;
1152+ assert_eq ! ( result. shape( ) , & [ ] ) ;
1153+ assert_eq ! ( result[ [ ] ] , 32.0 ) ; // 1*4 + 2*5 + 3*6
1154+ }
1155+
1156+ #[ test]
1157+ #[ should_panic( expected = "Dot product for ArrayD is only supported for 1D and 2D arrays" ) ]
1158+ fn test_arrayd_dot_3d ( ) {
1159+ // Test that 3D arrays are not supported
1160+ let arr1 = ArrayD :: from_shape_vec ( vec ! [ 2 , 2 , 2 ] , vec ! [ 1.0 ; 8 ] ) . unwrap ( ) ;
1161+ let arr2 = ArrayD :: from_shape_vec ( vec ! [ 2 , 2 , 2 ] , vec ! [ 1.0 ; 8 ] ) . unwrap ( ) ;
1162+
1163+ let _result = arr1. dot ( & arr2) ; // Should panic
1164+ }
1165+
1166+ #[ test]
1167+ #[ should_panic( expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication" ) ]
1168+ fn test_arrayd_dot_incompatible_dims ( ) {
1169+ // Test arrays with incompatible dimensions
1170+ let arr1 = ArrayD :: from_shape_vec ( vec ! [ 2 , 3 ] , vec ! [ 1.0 ; 6 ] ) . unwrap ( ) ;
1171+ let arr2 = ArrayD :: from_shape_vec ( vec ! [ 4 , 5 ] , vec ! [ 1.0 ; 20 ] ) . unwrap ( ) ;
1172+
1173+ let _result = arr1. dot ( & arr2) ; // Should panic
1174+ }
1175+
1176+ #[ test]
1177+ fn test_arrayd_dot_matrix_vector ( ) {
1178+ // Test matrix-vector multiplication
1179+ let mat = ArrayD :: from_shape_vec ( vec ! [ 3 , 2 ] , vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) . unwrap ( ) ;
1180+ let vec = ArrayD :: from_shape_vec ( vec ! [ 2 ] , vec ! [ 1.0 , 2.0 ] ) . unwrap ( ) ;
1181+
1182+ let result = mat. dot ( & vec) ;
1183+
1184+ // Verify result
1185+ assert_eq ! ( result. ndim( ) , 1 ) ;
1186+ assert_eq ! ( result. shape( ) , & [ 3 ] ) ;
1187+
1188+ // Compare with Array2 implementation
1189+ let mat_2d = Array2 :: from_shape_vec ( ( 3 , 2 ) , vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) . unwrap ( ) ;
1190+ let vec_1d = Array1 :: from_vec ( vec ! [ 1.0 , 2.0 ] ) ;
1191+ let expected = mat_2d. dot ( & vec_1d) ;
1192+
1193+ assert_eq ! ( result. into_dimensionality:: <Ix1 >( ) . unwrap( ) , expected) ;
1194+ }
1195+ }
0 commit comments