Skip to content

Commit 6fd7d93

Browse files
committed
Fix cov and add tests
1 parent 49d3632 commit 6fd7d93

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

scself/tests/test_correlation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import unittest
2+
3+
import numpy as np
4+
import scipy.sparse as sps
5+
6+
from scself.utils.correlation import cov
7+
8+
X = np.random.default_rng(100).random((500, 100))
9+
10+
11+
class TestCOV(unittest.TestCase):
12+
13+
def test_dense(self):
14+
15+
x_cov = cov(X)
16+
np.testing.assert_almost_equal(
17+
x_cov,
18+
np.cov(X.T),
19+
decimal=4
20+
)
21+
22+
def test_dense_axis1(self):
23+
24+
x_cov = cov(X, axis=1)
25+
np.testing.assert_almost_equal(
26+
x_cov,
27+
np.cov(X),
28+
decimal=4
29+
)
30+
31+
def test_sparse(self):
32+
33+
x_cov = cov(sps.csr_array(X))
34+
np.testing.assert_almost_equal(
35+
x_cov,
36+
np.cov(X.T),
37+
decimal=4
38+
)
39+
40+
41+
def test_sparse_axis1(self):
42+
43+
x_cov = cov(sps.csr_array(X), axis=1)
44+
np.testing.assert_almost_equal(
45+
x_cov,
46+
np.cov(X),
47+
decimal=4
48+
)

scself/utils/correlation.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,46 @@ def fit_transform(self, X):
2525
def cov(X, axis=0):
2626

2727
if sps.issparse(X):
28-
return cov_sparse(X)
28+
return cov_sparse(X, axis=axis)
2929

3030
# Center and get num rows
3131
avg, w_sum = np.average(X, axis=axis, weights=None, returned=True)
3232
w_sum = w_sum[0]
3333
X = X - (avg[None, :] if axis == 0 else avg[:, None])
3434

3535
# Gram matrix
36-
X = np.dot(X.T, X)
37-
X *= np.true_divide(1, w_sum)
36+
if axis == 0:
37+
X = np.dot(X.T, X)
38+
else:
39+
X = np.dot(X, X.T)
40+
41+
X *= np.true_divide(1, w_sum - 1)
3842

3943
return X
4044

4145
def cov_sparse(X, axis=0):
4246

43-
avg = X.mean(axis)
47+
axsum = X.sum(axis)
48+
w_sum = X.shape[axis]
4449

4550
# for spmatrix & sparray
4651
try:
47-
avg = avg.A1
52+
axsum = axsum.A1
4853
except AttributeError:
49-
avg = avg.ravel()
54+
axsum = axsum.ravel()
5055

51-
w_sum = X.shape[axis]
52-
X = dot(X.T, X, dense=True)
53-
X *= np.true_divide(1, w_sum)
56+
axsum = axsum.reshape(-1, 1).dot(axsum.reshape(1, -1))
57+
axsum /= w_sum
5458

55-
return X
59+
if axis == 0:
60+
X_cov = dot(X.T, X, dense=True)
61+
else:
62+
X_cov = dot(X, X.T, dense=True)
63+
64+
X_cov -= axsum
65+
X_cov /= (w_sum - 1)
66+
67+
return X_cov
5668

5769
def corrcoef(X, axis=0):
5870

0 commit comments

Comments
 (0)