diff --git a/src/aspire/basis/fspca.py b/src/aspire/basis/fspca.py index 0375f9183..d2f78b2ff 100644 --- a/src/aspire/basis/fspca.py +++ b/src/aspire/basis/fspca.py @@ -15,7 +15,7 @@ class FSPCABasis(SteerableBasis2D): A class for Fast Steerable Principal Component Analaysis basis. FSPCA is an extension to Fourier Bessel representations - (provided asF BBasis2D/FFBBasis2D), which computes combinations of basis + (provided as FFBasis2D/FLEBasis2D), which computes combinations of basis coefficients coresponding to the princicpal components of image(s) represented in the provided basis. diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index f3898f3fe..30106909a 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -100,9 +100,9 @@ def __init__(self, basis): self.dtype = self.basis.dtype assert basis.ndim == 2, "Only two-dimensional basis functions are needed." - def _ctf_identity_mat(self): + def _identity_mat(self): """ - Returns CTF identity corresponding to the `matrix_type` of `self.basis`. + Returns identity corresponding to the `matrix_type` of `self.basis`. :return: Identity BlkDiagMatrix or DiagMatrix """ @@ -111,6 +111,17 @@ def _ctf_identity_mat(self): else: return BlkDiagMatrix.eye(self.basis.blk_diag_cov_shape, dtype=self.dtype) + def _zeros_mat(self): + """ + Returns zero initialized matrix according to the `matrix_type` of `self.basis`. + + :return: Zeros BlkDiagMatrix or DiagMatrix + """ + if self.basis.matrix_type == DiagMatrix: + return DiagMatrix.zeros(self.basis.count, dtype=self.dtype) + else: + return BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, dtype=self.dtype) + def _get_mean(self, coefs): """ Calculate the mean vector from the expansion coefficients of 2D images without CTF information. @@ -211,7 +222,7 @@ def get_mean(self, coefs, ctf_basis=None, ctf_idx=None): # should assert we require none or both... if (ctf_basis is None) or (ctf_idx is None): ctf_idx = np.zeros(coefs.shape[0], dtype=int) - ctf_basis = [self._ctf_identity_mat()] + ctf_basis = [self._identity_mat()] b = np.zeros(self.basis.count, dtype=coefs.dtype) @@ -272,7 +283,7 @@ def get_covar( if (ctf_basis is None) or (ctf_idx is None): ctf_idx = np.zeros(coefs.shape[0], dtype=int) - ctf_basis = [self._ctf_identity_mat()] + ctf_basis = [self._identity_mat()] def identity(x): return x @@ -529,7 +540,7 @@ def _build(self): logger.info("CTF filters are not included in Cov2D denoising") # set all CTF filters to an identity filter self.ctf_idx = np.zeros(src.n, dtype=int) - self.ctf_basis = [self._ctf_identity_mat()] + self.ctf_basis = [self._identity_mat()] else: logger.info("Represent CTF filters in basis") @@ -596,7 +607,7 @@ def _calc_op(self): A_mean = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, self.dtype) A_covar = [None for _ in ctf_basis] - M_covar = BlkDiagMatrix.zeros_like(A_mean) + M_covar = self._zeros_mat() for k in np.unique(ctf_idx): weight = float(np.count_nonzero(ctf_idx == k) / src.n) @@ -609,7 +620,6 @@ def _calc_op(self): A_mean += A_mean_k A_covar_k = np.sqrt(weight).astype(self.dtype) * ctf_basis_k_sq A_covar[k] = A_covar_k - M_covar += A_covar_k self.A_mean = A_mean @@ -671,13 +681,20 @@ def _solve_covar(self, A_covar, b_covar, M, covar_est_opt): def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt): # A_covar is a list of DiagMatrix, representing each ctf in self.basis. # b_covar is a BlkDiagMatrix - # M is sum of weighted A squared, only used for cg, ignore here. - A_covar = DiagMatrix(np.concatenate([x.asnumpy() for x in A_covar])) - A2i = A_covar * A_covar + # M is sum of weighted A squared. + # covar_est_opt ignored + + # Because its cheap to compute for DiagMatrix, we'll log the conditioning here. + logger.debug(f"M condition: {M.condition()}") + + # Compute inverse + Minv = M.invert() - res = BlkDiagMatrix.empty(b_covar.nblocks, self.dtype) - for b in range(b_covar.nblocks): - res.data[b] = b_covar[b] / A2i[b] + # The combined left right scaling here is equivalent to + # looping & building dense array blks of the diagonals cross + # multiplied to be used for elementwise division as was done + # in Yunpeng's code. + res = Minv @ b_covar @ Minv return res @@ -788,7 +805,7 @@ def identity(x): if not self.b_covar: self._calc_rhs() - if not self.A_covar or self.M_covar: + if not self.A_covar or (self.M_covar is not None): self._calc_op() if mean_coef is None: diff --git a/src/aspire/operators/diag_matrix.py b/src/aspire/operators/diag_matrix.py index 4c94ab83d..106a7def5 100644 --- a/src/aspire/operators/diag_matrix.py +++ b/src/aspire/operators/diag_matrix.py @@ -414,6 +414,15 @@ def __pow__(self, val): return self.pow(val) + def invert(self): + """ + Return `DiagMatrix` instance containing reciprocal elements, + representing the mathematical inverse. + + :return: `DiagMatrix` instance. + """ + return DiagMatrix(1 / self._data) + @property def norm(self): """ @@ -508,6 +517,14 @@ def eigvals(self): """ return self.asnumpy() + def condition(self): + """ + Return the condition number of this matrix. + + :return: Condition number as a float + """ + return np.max(self.asnumpy()) / np.min(self.asnumpy()) + @staticmethod def empty(shape, dtype=np.float32): """ diff --git a/tests/test_diag_matrix.py b/tests/test_diag_matrix.py index 05805912c..8b99029c7 100644 --- a/tests/test_diag_matrix.py +++ b/tests/test_diag_matrix.py @@ -488,6 +488,26 @@ def test_pow(diag_matrix_fixture): np.testing.assert_allclose(d1.pow(2), ref) +def test_invert(diag_matrix_fixture): + """ + Test inversion + """ + d1, _, d_np = diag_matrix_fixture + + ref = 1 / d_np[0] + np.testing.assert_allclose(d1.invert(), ref) + + +def test_condition(diag_matrix_fixture): + """ + Test condition number method + """ + d1, _, d_np = diag_matrix_fixture + + ref = np.max(d_np[0]) / np.min(d_np[0]) + np.testing.assert_allclose(d1.condition(), ref) + + def test_norm(diag_matrix_fixture): """ Test the `norm` compared to Numpy.