Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/aspire/basis/fspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class FSPCABasis(SteerableBasis2D):
A class for Fast Steerable Principal Component Analaysis basis.

FSPCA is an extension to Fourier Bessel representations
(provided asF BBasis2D/FFBBasis2D), which computes combinations of basis
(provided as FFBasis2D/FLEBasis2D), which computes combinations of basis
coefficients coresponding to the princicpal components of image(s)
represented in the provided basis.

Expand Down
45 changes: 31 additions & 14 deletions src/aspire/covariance/covar2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def __init__(self, basis):
self.dtype = self.basis.dtype
assert basis.ndim == 2, "Only two-dimensional basis functions are needed."

def _ctf_identity_mat(self):
def _identity_mat(self):
"""
Returns CTF identity corresponding to the `matrix_type` of `self.basis`.
Returns identity corresponding to the `matrix_type` of `self.basis`.

:return: Identity BlkDiagMatrix or DiagMatrix
"""
Expand All @@ -111,6 +111,17 @@ def _ctf_identity_mat(self):
else:
return BlkDiagMatrix.eye(self.basis.blk_diag_cov_shape, dtype=self.dtype)

def _zeros_mat(self):
"""
Returns zero initialized matrix according to the `matrix_type` of `self.basis`.

:return: Zeros BlkDiagMatrix or DiagMatrix
"""
if self.basis.matrix_type == DiagMatrix:
return DiagMatrix.zeros(self.basis.count, dtype=self.dtype)
else:
return BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, dtype=self.dtype)

def _get_mean(self, coefs):
"""
Calculate the mean vector from the expansion coefficients of 2D images without CTF information.
Expand Down Expand Up @@ -211,7 +222,7 @@ def get_mean(self, coefs, ctf_basis=None, ctf_idx=None):
# should assert we require none or both...
if (ctf_basis is None) or (ctf_idx is None):
ctf_idx = np.zeros(coefs.shape[0], dtype=int)
ctf_basis = [self._ctf_identity_mat()]
ctf_basis = [self._identity_mat()]

b = np.zeros(self.basis.count, dtype=coefs.dtype)

Expand Down Expand Up @@ -272,7 +283,7 @@ def get_covar(

if (ctf_basis is None) or (ctf_idx is None):
ctf_idx = np.zeros(coefs.shape[0], dtype=int)
ctf_basis = [self._ctf_identity_mat()]
ctf_basis = [self._identity_mat()]

def identity(x):
return x
Expand Down Expand Up @@ -529,7 +540,7 @@ def _build(self):
logger.info("CTF filters are not included in Cov2D denoising")
# set all CTF filters to an identity filter
self.ctf_idx = np.zeros(src.n, dtype=int)
self.ctf_basis = [self._ctf_identity_mat()]
self.ctf_basis = [self._identity_mat()]

else:
logger.info("Represent CTF filters in basis")
Expand Down Expand Up @@ -596,7 +607,7 @@ def _calc_op(self):

A_mean = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, self.dtype)
A_covar = [None for _ in ctf_basis]
M_covar = BlkDiagMatrix.zeros_like(A_mean)
M_covar = self._zeros_mat()

for k in np.unique(ctf_idx):
weight = float(np.count_nonzero(ctf_idx == k) / src.n)
Expand All @@ -609,7 +620,6 @@ def _calc_op(self):
A_mean += A_mean_k
A_covar_k = np.sqrt(weight).astype(self.dtype) * ctf_basis_k_sq
A_covar[k] = A_covar_k

M_covar += A_covar_k

self.A_mean = A_mean
Expand Down Expand Up @@ -671,13 +681,20 @@ def _solve_covar(self, A_covar, b_covar, M, covar_est_opt):
def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt):
# A_covar is a list of DiagMatrix, representing each ctf in self.basis.
# b_covar is a BlkDiagMatrix
# M is sum of weighted A squared, only used for cg, ignore here.
A_covar = DiagMatrix(np.concatenate([x.asnumpy() for x in A_covar]))
A2i = A_covar * A_covar
# M is sum of weighted A squared.
# covar_est_opt ignored

# Because its cheap to compute for DiagMatrix, we'll log the conditioning here.
logger.debug(f"M condition: {M.condition()}")

# Compute inverse
Minv = M.invert()

res = BlkDiagMatrix.empty(b_covar.nblocks, self.dtype)
for b in range(b_covar.nblocks):
res.data[b] = b_covar[b] / A2i[b]
# The combined left right scaling here is equivalent to
# looping & building dense array blks of the diagonals cross
# multiplied to be used for elementwise division as was done
# in Yunpeng's code.
res = Minv @ b_covar @ Minv

return res

Expand Down Expand Up @@ -788,7 +805,7 @@ def identity(x):
if not self.b_covar:
self._calc_rhs()

if not self.A_covar or self.M_covar:
if not self.A_covar or (self.M_covar is not None):
self._calc_op()

if mean_coef is None:
Expand Down
17 changes: 17 additions & 0 deletions src/aspire/operators/diag_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,15 @@ def __pow__(self, val):

return self.pow(val)

def invert(self):
"""
Return `DiagMatrix` instance containing reciprocal elements,
representing the mathematical inverse.

:return: `DiagMatrix` instance.
"""
return DiagMatrix(1 / self._data)

@property
def norm(self):
"""
Expand Down Expand Up @@ -508,6 +517,14 @@ def eigvals(self):
"""
return self.asnumpy()

def condition(self):
"""
Return the condition number of this matrix.

:return: Condition number as a float
"""
return np.max(self.asnumpy()) / np.min(self.asnumpy())

@staticmethod
def empty(shape, dtype=np.float32):
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/test_diag_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,26 @@ def test_pow(diag_matrix_fixture):
np.testing.assert_allclose(d1.pow(2), ref)


def test_invert(diag_matrix_fixture):
"""
Test inversion
"""
d1, _, d_np = diag_matrix_fixture

ref = 1 / d_np[0]
np.testing.assert_allclose(d1.invert(), ref)


def test_condition(diag_matrix_fixture):
"""
Test condition number method
"""
d1, _, d_np = diag_matrix_fixture

ref = np.max(d_np[0]) / np.min(d_np[0])
np.testing.assert_allclose(d1.condition(), ref)


def test_norm(diag_matrix_fixture):
"""
Test the `norm` compared to Numpy.
Expand Down
Loading