11// Still working on that
2- // Probably going to implement, either using Linfa_tsne or another approach
2+ // Probably going to implement, either using Linfa_tsne or another approach
3+
4+ use ndarray:: { Array2 , ArrayD , ArrayViewD } ;
5+ use single_utilities:: traits:: FloatOpsTS ;
6+
7+ pub struct TSNEConfig {
8+ output_dim : u8 ,
9+ perplexity : f32 ,
10+ epochs : usize ,
11+ theta : f32 ,
12+ }
13+
14+ pub fn run_f32 < T : FloatOpsTS > (
15+ x : ArrayViewD < f32 > ,
16+ config : TSNEConfig ,
17+ ) -> anyhow:: Result < ArrayD < f32 > > {
18+ let n_obs = x. shape ( ) [ 0 ] ;
19+ let n_dim = x. shape ( ) [ 1 ] ;
20+ let x_slice = x. as_slice ( ) . unwrap ( ) ;
21+
22+ let x_chunked_slice: Vec < & [ f32 ] > = x_slice. chunks ( n_dim) . collect ( ) ;
23+ let tsne_result = bhtsne:: tSNE:: new ( & x_chunked_slice)
24+ . embedding_dim ( config. output_dim )
25+ . perplexity ( config. perplexity )
26+ . epochs ( config. epochs )
27+ . barnes_hut ( config. theta , |sample_a, sample_b| {
28+ sample_a
29+ . iter ( )
30+ . zip ( sample_b. iter ( ) )
31+ . map ( |( & a, & b) | num_traits:: Float :: powi ( a - b, 2 ) )
32+ . sum :: < f32 > ( )
33+ . sqrt ( )
34+ } )
35+ . embedding ( ) ;
36+
37+ let result = Array2 :: from_shape_vec ( ( n_obs, config. output_dim as usize ) , tsne_result) ?;
38+ Ok ( result. into_dyn ( ) )
39+ }
40+
41+ pub fn run_f64 < T : FloatOpsTS > (
42+ x : ArrayViewD < f64 > ,
43+ config : TSNEConfig ,
44+ ) -> anyhow:: Result < ArrayD < f64 > > {
45+ let n_obs = x. shape ( ) [ 0 ] ;
46+ let n_dim = x. shape ( ) [ 1 ] ;
47+ let x_slice = x. as_slice ( ) . unwrap ( ) ;
48+
49+ let x_chunked_slice: Vec < & [ f64 ] > = x_slice. chunks ( n_dim) . collect ( ) ;
50+ let tsne_result = bhtsne:: tSNE:: new ( & x_chunked_slice)
51+ . embedding_dim ( config. output_dim )
52+ . perplexity ( config. perplexity as f64 )
53+ . epochs ( config. epochs )
54+ . barnes_hut ( config. theta as f64 , |sample_a, sample_b| {
55+ sample_a
56+ . iter ( )
57+ . zip ( sample_b. iter ( ) )
58+ . map ( |( & a, & b) | num_traits:: Float :: powi ( a - b, 2 ) )
59+ . sum :: < f64 > ( )
60+ . sqrt ( )
61+ } )
62+ . embedding ( ) ;
63+
64+ let result = Array2 :: from_shape_vec ( ( n_obs, config. output_dim as usize ) , tsne_result) ?;
65+ Ok ( result. into_dyn ( ) )
66+ }
0 commit comments