@@ -6,10 +6,11 @@ use crate::network::Network;
66use crate :: similarity:: SimilarityMeasure ;
77use kiddo:: { KdTree , SquaredEuclidean } ;
88use nalgebra_sparse:: { CooMatrix , CsrMatrix } ;
9- use ndarray:: Array2 ;
9+ use ndarray:: { Array2 , ArrayD , ArrayViewD } ;
1010use num_traits:: { Float , FromPrimitive , ToPrimitive } ;
1111use petgraph:: graph:: UnGraph ;
1212use std:: collections:: HashSet ;
13+ use kiddo:: traits:: DistanceMetric ;
1314
1415pub fn create_similarity_network < T , S > (
1516 data : & Array2 < T > ,
5657 Network :: new_from_graph ( graph)
5758}
5859
60+ pub fn build_knn_network_combined_matrix_arrayd < T , const K : usize , D > (
61+ data : ArrayViewD < T > ,
62+ k : u64 ,
63+ ) -> anyhow:: Result < CsrMatrix < T > >
64+ where
65+ T : Float + FromPrimitive + ToPrimitive + Send + Sync + Default + ' static ,
66+ T : num_traits:: float:: FloatCore ,
67+ T : std:: fmt:: Debug ,
68+ T : std:: ops:: AddAssign ,
69+ D : DistanceMetric < T , K >
70+ {
71+ if data. ndim ( ) != 2 {
72+ return Err ( anyhow:: anyhow!( "The input array has to be two dimensional in order to be used to build a knn network." ) ) ;
73+ }
74+
75+ let shape = data. shape ( ) ;
76+ let n_samples = shape[ 0 ] as u64 ;
77+ let n_features = shape[ 1 ] as u64 ;
78+
79+ if ( n_features as usize ) < K {
80+ return Err ( anyhow:: anyhow!( "The data must have at least K features in order to be used for building a knn network." ) )
81+ }
82+
83+ let mut kdtree: KdTree < T , K > = KdTree :: new ( ) ;
84+
85+ for i in 0 ..n_samples {
86+ let mut point_array = [ T :: zero ( ) ; K ] ;
87+ for j in 0 ..K {
88+ point_array[ j] = * data. get ( [ i as usize , j] ) . unwrap_or ( & T :: zero ( ) ) ;
89+ }
90+ kdtree. add ( & point_array, i) ;
91+ }
92+
93+ let mut triplets = Vec :: with_capacity ( ( n_samples * k) as usize ) ;
94+
95+ for i in 0 ..n_samples {
96+ let mut query_array = [ T :: zero ( ) ; K ] ;
97+ for j in 0 ..K {
98+ query_array[ j] = * data. get ( [ i as usize , j] ) . unwrap_or ( & T :: zero ( ) ) ;
99+ }
100+
101+ let neighbors = kdtree. nearest_n :: < D > ( & query_array, ( k+1 ) as usize ) ;
102+
103+ for neighbor in neighbors. iter ( ) . skip ( 1 ) {
104+ if i <= neighbor. item {
105+ let weight = ( -neighbor. distance . sqrt ( ) ) . exp ( ) ;
106+ triplets. push ( ( i as usize , neighbor. item as usize , weight) ) ;
107+ }
108+ }
109+ }
110+
111+ let coo = CooMatrix :: try_from_triplets (
112+ n_samples as usize ,
113+ n_samples as usize ,
114+ triplets. iter ( ) . map ( |& ( i, _, _) | i) . collect ( ) ,
115+ triplets. iter ( ) . map ( |& ( _, j, _) | j) . collect ( ) ,
116+ triplets. iter ( ) . map ( |& ( _, _, v) | v) . collect ( ) ,
117+ )
118+ . map_err ( |e| anyhow:: anyhow!( "Failed to create COO matrix: {}" , e) ) ?;
119+
120+ Ok ( CsrMatrix :: from ( & coo) )
121+ }
122+
123+
59124pub fn build_knn_network_combined_matrix < T , const K : usize > (
60125 data : & Array2 < T > ,
61126 k : u64 ,
0 commit comments