@@ -2,11 +2,10 @@ use nalgebra_sparse::CscMatrix;
22use num_traits:: { Float , NumCast , PrimInt , Unsigned , Zero } ;
33use single_utilities:: types:: Direction ;
44use std:: collections:: { HashMap , HashSet } ;
5- use std:: hash:: Hash ;
65use std:: iter:: Sum ;
7- use std:: ops:: Add ;
86use std:: ops:: AddAssign ;
97
8+ use crate :: sparse:: MatrixNTop ;
109use crate :: utils:: Normalize ;
1110
1211use super :: {
@@ -359,8 +358,8 @@ where
359358
360359 fn var_col < I , T > ( & self ) -> anyhow:: Result < Vec < T > >
361360 where
362- I : PrimInt + Unsigned + Zero + AddAssign + Into < T > ,
363- T : Float + NumCast + AddAssign + std:: iter:: Sum ,
361+ I : PrimInt + Unsigned + Zero + AddAssign + Into < T > + Send + Sync ,
362+ T : Float + NumCast + AddAssign + std:: iter:: Sum + Send + Sync ,
364363 Self :: Item : NumCast ,
365364 {
366365 let sum: Vec < T > = self . sum_col ( ) ?;
@@ -385,8 +384,8 @@ where
385384
386385 fn var_row < I , T > ( & self ) -> anyhow:: Result < Vec < T > >
387386 where
388- I : PrimInt + Unsigned + Zero + AddAssign + Into < T > ,
389- T : Float + NumCast + AddAssign + std:: iter:: Sum ,
387+ I : PrimInt + Unsigned + Zero + AddAssign + Into < T > + Send + Sync ,
388+ T : Float + NumCast + AddAssign + std:: iter:: Sum + Send + Sync ,
390389 Self :: Item : NumCast ,
391390 {
392391 let sum: Vec < T > = self . sum_row ( ) ?;
@@ -410,8 +409,8 @@ where
410409
411410 fn var_col_chunk < I , T > ( & self , reference : & mut [ T ] ) -> anyhow:: Result < ( ) >
412411 where
413- I : PrimInt + Unsigned + Zero + AddAssign + Into < T > ,
414- T : Float + NumCast + AddAssign + std:: iter:: Sum ,
412+ I : PrimInt + Unsigned + Zero + AddAssign + Into < T > + Send + Sync ,
413+ T : Float + NumCast + AddAssign + std:: iter:: Sum + Send + Sync ,
415414 Self :: Item : NumCast ,
416415 {
417416 // Validate input slice length matches number of columns
@@ -449,8 +448,8 @@ where
449448
450449 fn var_row_chunk < I , T > ( & self , reference : & mut [ T ] ) -> anyhow:: Result < ( ) >
451450 where
452- I : PrimInt + Unsigned + Zero + AddAssign + Into < T > ,
453- T : Float + NumCast + AddAssign + std:: iter:: Sum ,
451+ I : PrimInt + Unsigned + Zero + AddAssign + Into < T > + Send + Sync ,
452+ T : Float + NumCast + AddAssign + std:: iter:: Sum + Send + Sync ,
454453 Self :: Item : NumCast ,
455454 {
456455 // Validate input slice length matches number of rows
@@ -488,8 +487,8 @@ where
488487
489488 fn var_col_masked < I , T > ( & self , mask : & [ bool ] ) -> anyhow:: Result < Vec < T > >
490489 where
491- I : PrimInt + Unsigned + Zero + AddAssign + Into < T > ,
492- T : Float + NumCast + AddAssign + Sum ,
490+ I : PrimInt + Unsigned + Zero + AddAssign + Into < T > + Send + Sync ,
491+ T : Float + NumCast + AddAssign + Sum + Send + Sync ,
493492 {
494493 // Validate mask length
495494 if mask. len ( ) < self . nrows ( ) {
@@ -537,8 +536,8 @@ where
537536
538537 fn var_row_masked < I , T > ( & self , mask : & [ bool ] ) -> anyhow:: Result < Vec < T > >
539538 where
540- I : PrimInt + Unsigned + Zero + AddAssign + Into < T > ,
541- T : Float + NumCast + AddAssign + Sum ,
539+ I : PrimInt + Unsigned + Zero + AddAssign + Into < T > + Send + Sync ,
540+ T : Float + NumCast + AddAssign + Sum + Send + Sync
542541 {
543542 // Validate mask length
544543 if mask. len ( ) < self . ncols ( ) {
@@ -590,7 +589,7 @@ impl<M: NumCast + Copy + PartialOrd + NumericOps> MatrixMinMax for CscMatrix<M>
590589
591590 fn min_max_col < Item > ( & self ) -> anyhow:: Result < ( Vec < Item > , Vec < Item > ) >
592591 where
593- Item : NumCast + Copy + PartialOrd + NumericOps ,
592+ Item : NumCast + Copy + PartialOrd + NumericOps + Send + Sync ,
594593 {
595594 let mut min: Vec < Item > = vec ! [ Item :: max_value( ) ; self . ncols( ) ] ;
596595 let mut max: Vec < Item > = vec ! [ Item :: min_value( ) ; self . ncols( ) ] ;
@@ -601,7 +600,7 @@ impl<M: NumCast + Copy + PartialOrd + NumericOps> MatrixMinMax for CscMatrix<M>
601600
602601 fn min_max_row < Item > ( & self ) -> anyhow:: Result < ( Vec < Item > , Vec < Item > ) >
603602 where
604- Item : NumCast + Copy + PartialOrd + NumericOps ,
603+ Item : NumCast + Copy + PartialOrd + NumericOps + Send + Sync ,
605604 {
606605 let mut min: Vec < Item > = vec ! [ Item :: max_value( ) ; self . nrows( ) ] ;
607606 let mut max: Vec < Item > = vec ! [ Item :: min_value( ) ; self . nrows( ) ] ;
@@ -1027,6 +1026,41 @@ impl<M: NumericOps + NumCast> BatchMatrixMean for CscMatrix<M> {
10271026 }
10281027}
10291028
1029+ impl < M : NumericOps + NumCast > MatrixNTop for CscMatrix < M > {
1030+ type Item = M ;
1031+
1032+ fn sum_row_n_top < T > ( & self , n : usize ) -> anyhow:: Result < Vec < T > >
1033+ where
1034+ T : Float + NumCast + AddAssign + Sum {
1035+ let mut result = vec ! [ T :: zero( ) ; self . nrows( ) ] ;
1036+
1037+ let mut row_values: Vec < Vec < T > > = vec ! [ Vec :: new( ) ; self . nrows( ) ] ;
1038+
1039+ for col_idx in 0 ..self . ncols ( ) {
1040+ let col_start = self . col_offsets ( ) [ col_idx] ;
1041+ let col_end = self . col_offsets ( ) [ col_idx + 1 ] ;
1042+
1043+ for idx in col_start..col_end {
1044+ let row_idx = self . row_indices ( ) [ idx] ;
1045+ if let Some ( val) = T :: from ( self . values ( ) [ idx] ) {
1046+ row_values[ row_idx] . push ( val) ;
1047+ }
1048+ }
1049+ }
1050+
1051+ for ( row_idx, mut values) in row_values. into_iter ( ) . enumerate ( ) {
1052+ if values. len ( ) <= n {
1053+ result[ row_idx] = values. into_iter ( ) . sum ( ) ;
1054+ } else {
1055+ values. sort_by ( |a, b| b. partial_cmp ( a) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
1056+ result[ row_idx] = values. into_iter ( ) . take ( n) . sum ( ) ;
1057+ }
1058+ }
1059+
1060+ Ok ( result)
1061+ }
1062+ }
1063+
10301064#[ cfg( test) ]
10311065mod tests {
10321066 use Direction ;
0 commit comments