@@ -1109,7 +1109,8 @@ where
11091109{
11101110 type Output = Array < A , IxDyn > ;
11111111
1112- fn dot ( & self , rhs : & ArrayBase < S2 , IxDyn > ) -> Self :: Output {
1112+ fn dot ( & self , rhs : & ArrayBase < S2 , IxDyn > ) -> Self :: Output
1113+ {
11131114 match ( self . ndim ( ) , rhs. ndim ( ) ) {
11141115 ( 1 , 1 ) => {
11151116 let a = self . view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
@@ -1144,38 +1145,41 @@ where
11441145}
11451146
11461147#[ cfg( test) ]
1147- mod arrayd_dot_tests {
1148+ mod arrayd_dot_tests
1149+ {
11481150 use super :: * ;
11491151 use crate :: ArrayD ;
11501152
11511153 #[ test]
1152- fn test_arrayd_dot_2d ( ) {
1154+ fn test_arrayd_dot_2d ( )
1155+ {
11531156 // Test case from the original issue
11541157 let mat1 = ArrayD :: from_shape_vec ( vec ! [ 3 , 2 ] , vec ! [ 3.0 ; 6 ] ) . unwrap ( ) ;
11551158 let mat2 = ArrayD :: from_shape_vec ( vec ! [ 2 , 3 ] , vec ! [ 1.0 ; 6 ] ) . unwrap ( ) ;
1156-
1159+
11571160 let result = mat1. dot ( & mat2) ;
1158-
1161+
11591162 // Verify the result is correct
11601163 assert_eq ! ( result. ndim( ) , 2 ) ;
11611164 assert_eq ! ( result. shape( ) , & [ 3 , 3 ] ) ;
1162-
1165+
11631166 // Compare with Array2 implementation
11641167 let mat1_2d = Array2 :: from_shape_vec ( ( 3 , 2 ) , vec ! [ 3.0 ; 6 ] ) . unwrap ( ) ;
11651168 let mat2_2d = Array2 :: from_shape_vec ( ( 2 , 3 ) , vec ! [ 1.0 ; 6 ] ) . unwrap ( ) ;
11661169 let expected = mat1_2d. dot ( & mat2_2d) ;
1167-
1170+
11681171 assert_eq ! ( result. into_dimensionality:: <Ix2 >( ) . unwrap( ) , expected) ;
11691172 }
11701173
11711174 #[ test]
1172- fn test_arrayd_dot_1d ( ) {
1175+ fn test_arrayd_dot_1d ( )
1176+ {
11731177 // Test 1D array dot product
11741178 let vec1 = ArrayD :: from_shape_vec ( vec ! [ 3 ] , vec ! [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
11751179 let vec2 = ArrayD :: from_shape_vec ( vec ! [ 3 ] , vec ! [ 4.0 , 5.0 , 6.0 ] ) . unwrap ( ) ;
1176-
1180+
11771181 let result = vec1. dot ( & vec2) ;
1178-
1182+
11791183 // Verify scalar result
11801184 assert_eq ! ( result. ndim( ) , 0 ) ;
11811185 assert_eq ! ( result. shape( ) , & [ ] ) ;
@@ -1184,41 +1188,44 @@ mod arrayd_dot_tests {
11841188
11851189 #[ test]
11861190 #[ should_panic( expected = "Dot product for ArrayD is only supported for 1D and 2D arrays" ) ]
1187- fn test_arrayd_dot_3d ( ) {
1191+ fn test_arrayd_dot_3d ( )
1192+ {
11881193 // Test that 3D arrays are not supported
11891194 let arr1 = ArrayD :: from_shape_vec ( vec ! [ 2 , 2 , 2 ] , vec ! [ 1.0 ; 8 ] ) . unwrap ( ) ;
11901195 let arr2 = ArrayD :: from_shape_vec ( vec ! [ 2 , 2 , 2 ] , vec ! [ 1.0 ; 8 ] ) . unwrap ( ) ;
1191-
1196+
11921197 let _result = arr1. dot ( & arr2) ; // Should panic
11931198 }
11941199
11951200 #[ test]
11961201 #[ should_panic( expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication" ) ]
1197- fn test_arrayd_dot_incompatible_dims ( ) {
1202+ fn test_arrayd_dot_incompatible_dims ( )
1203+ {
11981204 // Test arrays with incompatible dimensions
11991205 let arr1 = ArrayD :: from_shape_vec ( vec ! [ 2 , 3 ] , vec ! [ 1.0 ; 6 ] ) . unwrap ( ) ;
12001206 let arr2 = ArrayD :: from_shape_vec ( vec ! [ 4 , 5 ] , vec ! [ 1.0 ; 20 ] ) . unwrap ( ) ;
1201-
1207+
12021208 let _result = arr1. dot ( & arr2) ; // Should panic
12031209 }
12041210
12051211 #[ test]
1206- fn test_arrayd_dot_matrix_vector ( ) {
1212+ fn test_arrayd_dot_matrix_vector ( )
1213+ {
12071214 // Test matrix-vector multiplication
12081215 let mat = ArrayD :: from_shape_vec ( vec ! [ 3 , 2 ] , vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) . unwrap ( ) ;
12091216 let vec = ArrayD :: from_shape_vec ( vec ! [ 2 ] , vec ! [ 1.0 , 2.0 ] ) . unwrap ( ) ;
1210-
1217+
12111218 let result = mat. dot ( & vec) ;
1212-
1219+
12131220 // Verify result
12141221 assert_eq ! ( result. ndim( ) , 1 ) ;
12151222 assert_eq ! ( result. shape( ) , & [ 3 ] ) ;
1216-
1223+
12171224 // Compare with Array2 implementation
12181225 let mat_2d = Array2 :: from_shape_vec ( ( 3 , 2 ) , vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ] ) . unwrap ( ) ;
12191226 let vec_1d = Array1 :: from_vec ( vec ! [ 1.0 , 2.0 ] ) ;
12201227 let expected = mat_2d. dot ( & vec_1d) ;
1221-
1228+
12221229 assert_eq ! ( result. into_dimensionality:: <Ix1 >( ) . unwrap( ) , expected) ;
12231230 }
12241231}
0 commit comments