Skip to content

Commit 3eb7db2

Browse files
authored
Merge pull request #32 from SingleRust/main
Upgraded sodlib performance
2 parents 897fc3c + e6d2ff5 commit 3eb7db2

5 files changed

Lines changed: 505 additions & 197 deletions

File tree

Cargo.lock

Lines changed: 12 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "single_algebra"
3-
version = "0.7.0"
3+
version = "0.8.3"
44
edition = "2021"
55
license-file = "LICENSE.md"
66
description = "A linear algebra convenience library for the single-rust library. Can be used externally as well."
@@ -37,7 +37,7 @@ simba = { version = "0.9.0", optional = true }
3737
smartcore = { version = "0.4", features = ["ndarray-bindings"], optional = true }
3838
single-svdlib = { version = "1.0.4" }
3939
rand = "0.9.0"
40-
single-utilities = "0.6.0"
40+
single-utilities = "0.7.0"
4141

4242
[dev-dependencies]
4343
criterion = "0.5.1"

src/sparse/csc.rs

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ use nalgebra_sparse::CscMatrix;
22
use num_traits::{Float, NumCast, PrimInt, Unsigned, Zero};
33
use single_utilities::types::Direction;
44
use std::collections::{HashMap, HashSet};
5-
use std::hash::Hash;
65
use std::iter::Sum;
7-
use std::ops::Add;
86
use std::ops::AddAssign;
97

8+
use crate::sparse::MatrixNTop;
109
use crate::utils::Normalize;
1110

1211
use super::{
@@ -359,8 +358,8 @@ where
359358

360359
fn var_col<I, T>(&self) -> anyhow::Result<Vec<T>>
361360
where
362-
I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
363-
T: Float + NumCast + AddAssign + std::iter::Sum,
361+
I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
362+
T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
364363
Self::Item: NumCast,
365364
{
366365
let sum: Vec<T> = self.sum_col()?;
@@ -385,8 +384,8 @@ where
385384

386385
fn var_row<I, T>(&self) -> anyhow::Result<Vec<T>>
387386
where
388-
I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
389-
T: Float + NumCast + AddAssign + std::iter::Sum,
387+
I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
388+
T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
390389
Self::Item: NumCast,
391390
{
392391
let sum: Vec<T> = self.sum_row()?;
@@ -410,8 +409,8 @@ where
410409

411410
fn var_col_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
412411
where
413-
I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
414-
T: Float + NumCast + AddAssign + std::iter::Sum,
412+
I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
413+
T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
415414
Self::Item: NumCast,
416415
{
417416
// Validate input slice length matches number of columns
@@ -449,8 +448,8 @@ where
449448

450449
fn var_row_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
451450
where
452-
I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
453-
T: Float + NumCast + AddAssign + std::iter::Sum,
451+
I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
452+
T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
454453
Self::Item: NumCast,
455454
{
456455
// Validate input slice length matches number of rows
@@ -488,8 +487,8 @@ where
488487

489488
fn var_col_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
490489
where
491-
I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
492-
T: Float + NumCast + AddAssign + Sum,
490+
I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
491+
T: Float + NumCast + AddAssign + Sum + Send + Sync,
493492
{
494493
// Validate mask length
495494
if mask.len() < self.nrows() {
@@ -537,8 +536,8 @@ where
537536

538537
fn var_row_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
539538
where
540-
I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
541-
T: Float + NumCast + AddAssign + Sum,
539+
I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
540+
T: Float + NumCast + AddAssign + Sum + Send + Sync
542541
{
543542
// Validate mask length
544543
if mask.len() < self.ncols() {
@@ -590,7 +589,7 @@ impl<M: NumCast + Copy + PartialOrd + NumericOps> MatrixMinMax for CscMatrix<M>
590589

591590
fn min_max_col<Item>(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
592591
where
593-
Item: NumCast + Copy + PartialOrd + NumericOps,
592+
Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
594593
{
595594
let mut min: Vec<Item> = vec![Item::max_value(); self.ncols()];
596595
let mut max: Vec<Item> = vec![Item::min_value(); self.ncols()];
@@ -601,7 +600,7 @@ impl<M: NumCast + Copy + PartialOrd + NumericOps> MatrixMinMax for CscMatrix<M>
601600

602601
fn min_max_row<Item>(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
603602
where
604-
Item: NumCast + Copy + PartialOrd + NumericOps,
603+
Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
605604
{
606605
let mut min: Vec<Item> = vec![Item::max_value(); self.nrows()];
607606
let mut max: Vec<Item> = vec![Item::min_value(); self.nrows()];
@@ -1027,6 +1026,41 @@ impl<M: NumericOps + NumCast> BatchMatrixMean for CscMatrix<M> {
10271026
}
10281027
}
10291028

1029+
impl<M: NumericOps + NumCast> MatrixNTop for CscMatrix<M> {
1030+
type Item = M;
1031+
1032+
fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
1033+
where
1034+
T: Float + NumCast + AddAssign + Sum {
1035+
let mut result = vec![T::zero(); self.nrows()];
1036+
1037+
let mut row_values: Vec<Vec<T>> = vec![Vec::new(); self.nrows()];
1038+
1039+
for col_idx in 0..self.ncols() {
1040+
let col_start = self.col_offsets()[col_idx];
1041+
let col_end = self.col_offsets()[col_idx + 1];
1042+
1043+
for idx in col_start..col_end {
1044+
let row_idx = self.row_indices()[idx];
1045+
if let Some(val) = T::from(self.values()[idx]) {
1046+
row_values[row_idx].push(val);
1047+
}
1048+
}
1049+
}
1050+
1051+
for (row_idx, mut values) in row_values.into_iter().enumerate() {
1052+
if values.len() <= n {
1053+
result[row_idx] = values.into_iter().sum();
1054+
} else {
1055+
values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1056+
result[row_idx] = values.into_iter().take(n).sum();
1057+
}
1058+
}
1059+
1060+
Ok(result)
1061+
}
1062+
}
1063+
10301064
#[cfg(test)]
10311065
mod tests {
10321066
use Direction;

0 commit comments

Comments
 (0)