diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py
index 66674bc..660dde7 100644
--- a/src/ptychi/api/options/base.py
+++ b/src/ptychi/api/options/base.py
@@ -542,7 +542,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions):
     method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD
     """The method to use for incoherent mode orthogonalization."""
-
+
+    sort_by_occupancy: bool = False
+    """Keep the incoherent modes sorted so that the mode with the highest occupancy is the 0th shared mode."""
 
 
 @dataclasses.dataclass
 class ProbeOrthogonalizeOPRModesOptions(FeatureOptions):
@@ -648,25 +650,35 @@ def get_non_data_fields(self) -> dict:
 
 
 @dataclasses.dataclass
 class SynthesisDictLearnProbeOptions(Options):
-
-    d_mat: Union[ndarray, Tensor] = None
+
+    enabled: bool = False
+    enabled_shared: bool = False
+    enabled_opr: bool = False
+
+    thresholding_type_shared: str = 'hard'
+    thresholding_type_opr: str = 'hard'
+    """Choose between 'hard' and 'soft' thresholding."""
+
+    dictionary_matrix: Union[ndarray, Tensor] = None
     """The synthesis sparse dictionary matrix; contains the basis functions that
     will be used to represent the probe via the sparse code weights."""
 
-    d_mat_conj_transpose: Union[ndarray, Tensor] = None
-    """Conjugate transpose of the synthesis sparse dictionary matrix."""
-
-    d_mat_pinv: Union[ndarray, Tensor] = None
+    dictionary_matrix_pinv: Union[ndarray, Tensor] = None
     """Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix."""
 
-    probe_sparse_code: Union[ndarray, Tensor] = None
-    """Sparse code weights vector."""
+    sparse_code_probe_shared: Union[ndarray, Tensor] = None
+    """Sparse code weights vector for the shared modes."""
 
-    probe_sparse_code_nnz: float = None
+    sparse_code_probe_shared_nnz: float = None
     """Number of non-zeros we will keep when enforcing the sparsity constraint on
-    the sparse code weights vector probe_sparse_code."""
+    the SHARED sparse code weights vector sparse_code_probe_shared."""
+
+    sparse_code_probe_opr: Union[ndarray, Tensor] = None
+    """Sparse code weights vector for the OPR modes."""
 
-    enabled: bool = False
+    sparse_code_probe_opr_nnz: float = None
+    """Number of non-zeros we will keep when enforcing the sparsity constraint on
+    the OPR sparse code weights vector sparse_code_probe_opr."""
 
 
 @dataclasses.dataclass
 class PositionCorrectionOptions(Options):
@@ -869,6 +881,12 @@ class OPRModeWeightsOptions(ParameterOptions):
     A separate step size for eigenmode weight update.
     """
 
+    use_optimal_update: bool = False
+    """
+    The default method does not compute an optimal update step for the OPR
+    weights and eigenmodes; this option adds that computation.
+    """
+
     def check(self, options: "task_options.PtychographyTaskOptions"):
         super().check(options)
         if self.optimizable:
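For context, a minimal sketch of how the renamed fields might be populated. This is illustrative only: the random dictionary, the sizes, and the nnz value are hypothetical placeholders, not values prescribed by this PR.

    import numpy as np
    from ptychi.api.options import base

    h = w = 64       # probe height/width (hypothetical)
    n_atoms = 256    # number of dictionary atoms (hypothetical)

    # Toy complex dictionary; in practice this is a learned or analytic basis
    # whose columns are vectorized (h * w) basis functions.
    rng = np.random.default_rng(0)
    d = rng.standard_normal((h * w, n_atoms)) + 1j * rng.standard_normal((h * w, n_atoms))

    sdl_options = base.SynthesisDictLearnProbeOptions(
        enabled=True,
        enabled_shared=True,
        thresholding_type_shared="soft",
        dictionary_matrix=d.astype(np.complex64),
        dictionary_matrix_pinv=np.linalg.pinv(d).astype(np.complex64),
        sparse_code_probe_shared_nnz=32,  # keep ~32 largest-magnitude coefficients
    )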
""" - - + experimental: LSQMLProbeExperimentalOptions = dataclasses.field(default_factory=LSQMLProbeExperimentalOptions) + + @dataclasses.dataclass class LSQMLProbePositionOptions(base.ProbePositionOptions): pass diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index 86f76a3..1807b7b 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -169,7 +169,9 @@ def build_probe(self): ): self.probe = probe.DIPProbe(**kwargs) elif ( - isinstance(self.probe_options, api.options.PIEProbeOptions) + isinstance(self.probe_options, api.options.PIEProbeOptions) + or + isinstance(self.probe_options, api.options.LSQMLProbeOptions) ) and ( self.probe_options.experimental.sdl_probe_options.enabled ): diff --git a/src/ptychi/data_structures/opr_mode_weights.py b/src/ptychi/data_structures/opr_mode_weights.py index 7c0c032..8d90aac 100644 --- a/src/ptychi/data_structures/opr_mode_weights.py +++ b/src/ptychi/data_structures/opr_mode_weights.py @@ -97,6 +97,7 @@ def intensity_variation_optimization_enabled(self, epoch: int): def update_variable_probe( self, probe: "Probe", + adjoint_shift_probe_update_direction, # what do I do for type hint here? indices: Tensor, chi: Tensor, delta_p_i: Tensor, @@ -117,8 +118,14 @@ def update_variable_probe( probe.optimization_enabled(current_epoch) or (self.eigenmode_weight_optimization_enabled(current_epoch)) ): - self.update_opr_probe_modes_and_weights( - probe, indices, chi, delta_p_i, delta_p_hat, obj_patches, current_epoch + self.update_opr_probe_modes_and_weights(probe, + adjoint_shift_probe_update_direction, + indices, + chi, + delta_p_i, + delta_p_hat, + obj_patches, + current_epoch ) if self.intensity_variation_optimization_enabled(current_epoch): @@ -134,6 +141,7 @@ def update_variable_probe( def update_opr_probe_modes_and_weights( self, probe: "Probe", + adjoint_shift_probe_update_direction, # what do I do for type hint here? indices: Tensor, chi: Tensor, delta_p_i: Tensor, @@ -144,12 +152,12 @@ def update_opr_probe_modes_and_weights( """ Update the eigenmodes of the first incoherent mode of the probe, and update the OPR mode weights. - This implementation is adapted from PtychoShelves code (update_variable_probe.m) and has some - differences from Eq. 31 of Odstrcil (2018). + The default (for self.options.use_optimal_update = False) implementation below is adapted from + PtychoShelves code (update_variable_probe.m) and has some differences from Eq. 31 of Odstrcil (2018). """ probe_data = probe.data weights_data = self.data - + batch_size = len(delta_p_i) n_points_total = self.n_scan_points @@ -158,44 +166,165 @@ def update_opr_probe_modes_and_weights( if batch_size == 1: return - # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. - relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation - relax_v = self.options.update_relaxation - # Shape of delta_p_i: (batch_size, n_probe_modes, h, w) - # Use only the first incoherent mode - delta_p_i = delta_p_i[:, 0, :, :] - delta_p_hat = delta_p_hat[0, :, :] - residue_update = delta_p_i - delta_p_hat - - # Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode. - for i_opr_mode in range(1, probe.n_opr_modes): - # Just take the first incoherent mode. 
-            eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode)
-            weights_i = self.get_weights(indices)[:, i_opr_mode]
-            eigenmode_i, weights_i = self._update_first_eigenmode_and_weight(
-                residue_update,
-                eigenmode_i,
-                weights_i,
-                relax_u,
-                relax_v,
-                obj_patches,
-                chi,
-                update_eigenmode=probe.optimization_enabled(current_epoch),
-                update_weights=self.eigenmode_weight_optimization_enabled(current_epoch),
-            )
-
-            # Project residue on this eigenmode, then subtract it.
-            if i_opr_mode < probe.n_opr_modes - 1:
-                residue_update = residue_update - pmath.project(
-                    residue_update, eigenmode_i, dim=(-2, -1)
-                )
-
-            probe_data[i_opr_mode, 0, :, :] = eigenmode_i
-            weights_data[indices, i_opr_mode] = weights_i
+        # Re-check both flags: the guard in update_variable_probe is an OR of these two
+        # conditions, so either one alone can be false here.
+        update_eigenmode = probe.optimization_enabled(current_epoch)
+        update_eigenmode_weights = self.eigenmode_weight_optimization_enabled(current_epoch)
+
+        if self.options.use_optimal_update:
+
+            rc = obj_patches.shape[-2] * obj_patches.shape[-1]
+            n_spos = obj_patches.shape[0]
+
+            U = probe_data[1:, 0, ...]
+
+            Ws = weights_data[indices, 1:].to(torch.complex64)
+
+            Tsconj_chi = obj_patches[:, 0, ...].conj() * chi[:, 0, ...]
+            Tsconj_chi = adjoint_shift_probe_update_direction(indices, Tsconj_chi[:, None, ...], first_mode_only=True)
+
+            chi = adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True)
+
+            U = torch.reshape(U, (U.shape[0], rc))
+            chi_vec = torch.reshape(chi[:, 0, ...], (n_spos, rc))
+            Ts = torch.reshape(obj_patches[:, 0, ...], (n_spos, rc))
+            Tsconj_chi = torch.reshape(Tsconj_chi[:, 0, ...], (n_spos, rc)).T
+
+            # Optimal OPR weight updates
+            if update_eigenmode_weights:
+
+                delta_Ws = -2 * torch.real(U.conj() @ Tsconj_chi).to(torch.complex64)
+
+                Ts_U_deltaWs = Ts.T * (U.T @ delta_Ws)
+                numer = torch.sum(torch.real(chi_vec * Ts_U_deltaWs.H))
+                denom = torch.sum(torch.real(Ts_U_deltaWs.conj() * Ts_U_deltaWs))
+                optimal_step_deltaWs = self.options.update_relaxation * (numer / denom)
+
+                Ws = Ws + optimal_step_deltaWs * delta_Ws.T
+
+            if (probe.representation == "sparse_code"
+                    and probe.options.experimental.sdl_probe_options.enabled_opr):
+
+                # Optimal sparse code OPR mode updates
+                delta_U = -1 * Tsconj_chi @ Ws
+
+                delta_sparse_code_probe_opr = probe.dictionary_matrix.H @ delta_U
+
+                Gs = probe.dictionary_matrix @ delta_sparse_code_probe_opr @ Ws.T
+                TsHGsH = Ts.H * Gs.conj()
+                numer = torch.sum(torch.real(TsHGsH * chi_vec.T))
+                denom = torch.sum(torch.real(TsHGsH * TsHGsH.conj()))
+                optimal_step_sparse_code_probe_opr = probe.options.eigenmode_update_relaxation * (numer / denom)
+
+                sparse_code_probe_opr = probe.get_sparse_code_probe_opr_weights()
+
+                optimal_sparse_code_probe_opr = (
+                    sparse_code_probe_opr
+                    + optimal_step_sparse_code_probe_opr * delta_sparse_code_probe_opr.T
+                )
+
+                # Enforce sparsity constraint on the sparse code.
+                abs_sparse_code = torch.abs(optimal_sparse_code_probe_opr)
+                abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True)
+                sel = abs_sparse_code_sorted[0][:, probe.sparse_code_opr_nnz]
+                sparse_code_mask = abs_sparse_code >= sel[:, None]
+
+                # Hard or soft thresholding
+                if probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'hard':
+                    optimal_sparse_code_probe_opr = optimal_sparse_code_probe_opr * sparse_code_mask
+                elif probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'soft':
+                    optimal_sparse_code_probe_opr = (abs_sparse_code - sel[:, None]) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_sparse_code_probe_opr))
+
+                probe.set_sparse_code_probe_opr(optimal_sparse_code_probe_opr)
+
+                # Back to the dense OPR representation
+                U = (probe.dictionary_matrix @ optimal_sparse_code_probe_opr.T).T
+
+                # The OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc)).
+                U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U) ** 2, -1))[:, None]
+
+                U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1]))
+
+                probe_data[1:, 0, :, :] = U
+                weights_data[indices, 1:] = Ws.real
+
+            else:
+
+                # Optimal dense OPR mode updates
+                delta_U = -1 * Tsconj_chi @ Ws
+
+                Ts_deltaU_Ws = Ts.T * (delta_U @ Ws.T)
+                numer = torch.sum(torch.real(chi_vec * Ts_deltaU_Ws.H))
+                denom = torch.sum(torch.real(Ts_deltaU_Ws.conj() * Ts_deltaU_Ws))
+                optimal_step_deltaU = probe.options.eigenmode_update_relaxation * (numer / denom)
+
+                U = U + optimal_step_deltaU * delta_U.T
+
+                # The OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc)).
+                U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U) ** 2, -1))[:, None]
+
+                U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1]))
+
+                probe_data[1:, 0, :, :] = U
+                weights_data[indices, 1:] = Ws.real
+
+        else:
+
+            # PtychoShelves method for OPR updates
+
+            # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this.
+            relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation
+            relax_v = self.options.update_relaxation
+            # Shape of delta_p_i: (batch_size, n_probe_modes, h, w)
+            # Use only the first incoherent mode.
+            delta_p_i = delta_p_i[:, 0, :, :]
+            delta_p_hat = delta_p_hat[0, :, :]
+            residue_update = delta_p_i - delta_p_hat
+
+            # Start from the second OPR mode, which is the first after the main mode - i.e., the first eigenmode.
+            for i_opr_mode in range(1, probe.n_opr_modes):
+                # Just take the first incoherent mode.
+                eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode)
+                weights_i = self.get_weights(indices)[:, i_opr_mode]
+                eigenmode_i, weights_i = self._update_first_eigenmode_and_weight(
+                    residue_update,
+                    eigenmode_i,
+                    weights_i,
+                    relax_u,
+                    relax_v,
+                    obj_patches,
+                    chi,
+                    update_eigenmode=update_eigenmode,
+                    update_weights=update_eigenmode_weights,
+                )
+
+                # Project the residue on this eigenmode, then subtract it.
+                if i_opr_mode < probe.n_opr_modes - 1:
+                    residue_update = residue_update - pmath.project(
+                        residue_update, eigenmode_i, dim=(-2, -1)
+                    )
+
+                probe_data[i_opr_mode, 0, :, :] = eigenmode_i
+                weights_data[indices, i_opr_mode] = weights_i
 
-        if probe.optimization_enabled(current_epoch):
-            probe.set_data(probe_data)
-        if self.eigenmode_weight_optimization_enabled(current_epoch):
+        if update_eigenmode:
+            probe.set_data(probe_data)
+
+        if update_eigenmode_weights:
             self.set_data(weights_data)
 
     @timer()
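All of the optimal step lengths above follow the same least-squares pattern: for an update direction whose effect on the exit wave is A·d, the real step alpha minimizing ||chi - alpha * A·d||^2 is sum(Re(conj(A·d) * chi)) / sum(|A·d|^2), scaled by a relaxation factor. A standalone sketch of that pattern (function name hypothetical):

    import torch

    def optimal_step_length(chi: torch.Tensor, a_d: torch.Tensor, relaxation: float = 1.0) -> torch.Tensor:
        # Minimizes || chi - alpha * a_d ||^2 over real alpha; this is the
        # numer / denom pattern behind optimal_step_deltaWs, optimal_step_deltaU,
        # and optimal_step_sparse_code_probe_opr above.
        numer = torch.sum(torch.real(a_d.conj() * chi))
        denom = torch.sum(torch.real(a_d.conj() * a_d))
        return relaxation * numer / denom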
diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py
index 6710825..73345da 100644
--- a/src/ptychi/data_structures/probe.py
+++ b/src/ptychi/data_structures/probe.py
@@ -221,7 +221,12 @@ def constrain_incoherent_modes_orthogonality(self):
             return
 
         probe = self.data
-
+
+        if self.options.orthogonalize_incoherent_modes.sort_by_occupancy:
+            shared_occupancy = torch.sum(torch.abs(probe[0, ...]) ** 2, (-2, -1)) / torch.sum(torch.abs(probe[0, ...]) ** 2)
+            occupancy_order = torch.argsort(shared_occupancy, dim=0, descending=True)
+            probe[0, ...] = probe[0, occupancy_order, ...]
+
         norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1))
 
         if self.orthogonalize_incoherent_modes_method == "gs":
@@ -462,32 +467,58 @@ def __init__(self, name = "probe", options = None, *args, **kwargs):
         super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs)
 
-        dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H = self.get_dictionary()
+        dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary()
+
         self.register_buffer("dictionary_matrix", dictionary_matrix)
         self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv)
-        self.register_buffer("dictionary_matrix_H", dictionary_matrix_H)
 
-        probe_sparse_code_nnz = torch.tensor( self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.uint32 )
-        self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz )
+        sparse_code_probe_shared_nnz = torch.tensor(self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32)
+        sparse_code_probe_opr_nnz = torch.tensor(self.options.experimental.sdl_probe_options.sparse_code_probe_opr_nnz, dtype=torch.uint32)
+
+        self.register_buffer("sparse_code_probe_nnz", sparse_code_probe_shared_nnz)
+        self.register_buffer("sparse_code_opr_nnz", sparse_code_probe_opr_nnz)
+
+        sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights()
+        sparse_code_probe_opr = self.get_sparse_code_probe_opr_weights()
+
+        self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared))
+        self.register_parameter("sparse_code_probe_opr", torch.nn.Parameter(sparse_code_probe_opr))
 
-        sparse_code_probe = self.get_sparse_code_weights()
-        self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe))
-
         self.build_optimizer()
 
     def get_dictionary(self):
-        dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.d_mat, dtype=torch.complex64 )
-        dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_pinv, dtype=torch.complex64 )
-        dictionary_matrix_H = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_conj_transpose, dtype=torch.complex64 )
-        return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H
-
-    def get_sparse_code_weights(self):
-        sz = self.data.shape
-        probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3]))
-        probe_vec = torch.swapaxes( probe_vec, 0, -1)
-        sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec
-        return sparse_code_probe
+        dictionary_matrix = torch.tensor(self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64)
+        dictionary_matrix_pinv = torch.tensor(self.options.experimental.sdl_probe_options.dictionary_matrix_pinv, dtype=torch.complex64)
+        return dictionary_matrix, dictionary_matrix_pinv
+
+    def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions):
+        sz = probe_vs_scanpositions.shape
+        probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2] * sz[3]))
+        sparse_code_vs_scanpositions = torch.einsum('ij,klj->ikl', self.dictionary_matrix_pinv, probe_vec)
+        return sparse_code_vs_scanpositions
+
+    def get_sparse_code_probe_shared_weights(self):
+        probe_shared = self.data[0, ...]
+        sz = probe_shared.shape
+        probe_vec = torch.reshape(probe_shared, (sz[0], sz[1] * sz[2]))
+        sparse_code_probe_shared = self.dictionary_matrix_pinv @ probe_vec.T
+        return sparse_code_probe_shared.T
+
+    def get_sparse_code_probe_opr_weights(self):
+        probe_opr = self.data[1:, 0, ...]
+        sz = probe_opr.shape
+        probe_vec = torch.reshape(probe_opr, (sz[0], sz[1] * sz[2]))
+        sparse_code_probe_opr = self.dictionary_matrix_pinv @ probe_vec.T
+        return sparse_code_probe_opr.T
 
     def generate(self):
         """Generate the probe using the sparse code, and set the generated probe to
         self.data.
@@ -497,15 +528,51 @@ def generate(self):
         Tensor
             A (n_opr_modes, n_modes, h, w) tensor giving the generated probe.
         """
-        probe_vec = self.dictionary_matrix @ self.sparse_code_probe
-        probe_vec = torch.swapaxes( probe_vec, 0, -1)
-        probe = torch.reshape(probe_vec, *[self.data[0,...].shape])
-        probe = probe[None,...]
-
-        # we only use sparse codes for the shared modes, not the OPRs
-        probe = torch.cat((probe, self.data[1:,...]), 0)
-        self.set_data(probe)
+        if (self.options.experimental.sdl_probe_options.enabled_shared
+                and self.options.experimental.sdl_probe_options.enabled_opr):
+
+            sz = self.data.shape
+            probe = torch.zeros(sz, dtype=torch.complex64, device=self.data.device)
+
+            probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T
+            probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T
+
+            probe[0, ...] = torch.reshape(probe_shared.T, sz[1:])
+            probe[1:, 0, ...] = torch.reshape(probe_opr.T, (sz[0] - 1, sz[-2], sz[-1]))
+
+            self.set_data(probe)
+
+        elif (self.options.experimental.sdl_probe_options.enabled_shared
+                and not self.options.experimental.sdl_probe_options.enabled_opr):
+
+            sz = self.data.shape
+            probe = torch.zeros(sz, dtype=torch.complex64, device=self.data.device)
+
+            probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T
+
+            probe[0, ...] = torch.reshape(probe_shared.T, sz[1:])
+            probe[1:, 0, ...] = self.data[1:, 0, ...]
+
+            self.set_data(probe)
+
+        elif (self.options.experimental.sdl_probe_options.enabled_opr
+                and not self.options.experimental.sdl_probe_options.enabled_shared):
+
+            sz = self.data.shape
+            probe = torch.zeros(sz, dtype=torch.complex64, device=self.data.device)
+
+            probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T
+
+            probe[0, ...] = self.data[0, ...]
+            probe[1:, 0, ...] = torch.reshape(probe_opr.T, (sz[0] - 1, sz[-2], sz[-1]))
+
+            self.set_data(probe)
+
+        else:
+
+            probe = self.data
 
         return probe
 
     def build_optimizer(self):
@@ -514,12 +581,15 @@ def build_optimizer(self):
                 "Parameter {} is optimizable but no optimizer is specified.".format(self.name)
             )
         if self.optimizable:
-            self.optimizer = self.optimizer_class([self.sparse_code_probe], **self.optimizer_params)
-
-    def set_sparse_code(self, data):
-        self.sparse_code_probe.data = data
-
+            self.optimizer = self.optimizer_class([self.sparse_code_probe_shared], **self.optimizer_params)
+
+    def set_sparse_code_probe_shared(self, data):
+        self.sparse_code_probe_shared.data = data
+
+    def set_sparse_code_probe_opr(self, data):
+        self.sparse_code_probe_opr.data = data
+
 
 class DIPProbe(Probe):
     options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions"
diff --git a/src/ptychi/maths.py b/src/ptychi/maths.py
index d19efc0..292aa86 100644
--- a/src/ptychi/maths.py
+++ b/src/ptychi/maths.py
@@ -213,6 +213,10 @@ def orthogonalize_svd(
 def project(a, b, dim=None):
     """Return the complex vector projection of a onto b along the given axis."""
     projected_length = inner(a, b, dim=dim, keepdims=True) / inner(b, b, dim=dim, keepdims=True)
+
+    # Guard against division by zero where the inner product of b with itself is zero.
+    projected_length = torch.nan_to_num(projected_length, nan=0.0)
+
     return projected_length * b
 
 def inner(x, y, dim=None, keepdims=False):
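A quick illustration of the new guard in project (assumes a recent PyTorch where torch.nan_to_num supports complex tensors):

    import torch
    import ptychi.maths as pmath

    a = torch.ones(4, dtype=torch.complex64)
    b = torch.zeros(4, dtype=torch.complex64)  # inner(b, b) == 0
    # Previously 0/0 produced NaNs; with the guard the projection is all zeros.
    print(pmath.project(a, b, dim=0))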
diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py
index 412c8d8..89919c0 100644
--- a/src/ptychi/reconstructors/lsqml.py
+++ b/src/ptychi/reconstructors/lsqml.py
@@ -225,13 +225,79 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions)
             )
         )
 
-        # Calculate probe update direction.
-        delta_p_i_unshifted = self._calculate_probe_update_direction(
-            chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
-        )  # Eq. 24a
-        delta_p_i = self.adjoint_shift_probe_update_direction(
-            indices, delta_p_i_unshifted, first_mode_only=True
-        )
+        # TODO: move this to SynthesisDictLearnProbe class methods, so it can be used in rPIE as well.
+        if (self.parameter_group.probe.representation == "sparse_code"
+                and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared):
+
+            rc = chi.shape[-1] * chi.shape[-2]
+            n_scpm = chi.shape[-3]
+            n_spos = chi.shape[-4]
+
+            # Sparse code update directions vs scan position and shared probe modes.
+            obj_patches_slice_i_conj = torch.conj(obj_patches[:, i_slice, ...])
+            delta_sparse_code = self.adjoint_shift_probe_update_direction(
+                indices, chi * obj_patches_slice_i_conj[:, None, ...], first_mode_only=True
+            )
+            delta_sparse_code = torch.reshape(delta_sparse_code, (n_spos, n_scpm, rc))
+            delta_sparse_code = torch.einsum('ijk,kl->lij',
+                delta_sparse_code,
+                self.parameter_group.probe.dictionary_matrix.conj())
+
+            # Compute the optimal step length for the sparse code update.
+            dict_delta_sparse_code = torch.einsum('ij,jkl->ikl',
+                self.parameter_group.probe.dictionary_matrix,
+                delta_sparse_code)
+
+            obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, rc))
+            denom = torch.abs(dict_delta_sparse_code) ** 2 * obj_patches_vec.swapaxes(0, -1)[..., None]
+            denom = torch.einsum('ij,jik->ik',
+                torch.conj(obj_patches_vec),
+                denom)
+
+            chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True)
+            numer = torch.conj(dict_delta_sparse_code) * torch.reshape(chi_rm_subpx_shft,
+                (n_spos, n_scpm, rc)).permute(2, 0, 1)
+            numer = torch.einsum('ij,jik->ik',
+                torch.conj(obj_patches_vec),
+                numer)
+
+            # .real is used to throw away the small imaginary part due to numerical precision errors.
+            optimal_step_sparse_code = (numer / denom).real
+
+            optimal_delta_sparse_code = optimal_step_sparse_code[None, ...] * delta_sparse_code
+
+            # Enforce sparsity constraint on the sparse code.
+            abs_sparse_code = torch.abs(optimal_delta_sparse_code)
+            abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)
+
+            sel = abs_sparse_code_sorted[0][self.parameter_group.probe.sparse_code_probe_nnz, ...]
+            sparse_code_mask = (abs_sparse_code >= sel[None, ...])
+
+            # Hard or soft thresholding
+            if self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard':
+                optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask
+            elif self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft':
+                optimal_delta_sparse_code = (abs_sparse_code - sel[None, ...]) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code))
+
+            # Update the shared probe sparse codes using the average over scan positions.
+            sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights()
+            sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T
+            self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared)
+
+            delta_p_i = torch.einsum('ij,jlk->ilk',
+                self.parameter_group.probe.dictionary_matrix,
+                optimal_delta_sparse_code).permute(1, 2, 0)
+            delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-2], chi.shape[-1]))
+            delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True)
+
+        else:
+            # Calculate probe update direction.
+            delta_p_i_unshifted = self._calculate_probe_update_direction(
+                chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
+            )  # Eq. 24a
+
+            delta_p_i = self.adjoint_shift_probe_update_direction(
+                indices, delta_p_i_unshifted, first_mode_only=True
+            )
 
         delta_p_hat = self._precondition_probe_update_direction(delta_p_i)  # Eq. 25a
 
         # Update OPR modes and weights.
@@ -239,6 +305,7 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions)
         if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch):
             self.parameter_group.opr_mode_weights.update_variable_probe(
                 self.parameter_group.probe,
+                self.adjoint_shift_probe_update_direction,
                 indices,
                 chi,
                 delta_p_i,
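The hard/soft thresholding applied to the sparse codes here, in opr_mode_weights.py above, and in pie.py below follows one pattern; a standalone sketch (function name hypothetical, thresholding along the last axis):

    import torch

    def threshold_sparse_code(code: torch.Tensor, nnz: int, kind: str = "hard") -> torch.Tensor:
        # Keep roughly the nnz largest-magnitude coefficients along the last axis.
        mag = torch.abs(code)
        sel = torch.sort(mag, dim=-1, descending=True).values[..., nnz]
        mask = mag >= sel[..., None]
        if kind == "hard":
            return code * mask  # zero out the small coefficients
        # Soft: shrink the surviving magnitudes by the threshold, keep the phase.
        return (mag - sel[..., None]) * mask * torch.exp(1j * torch.angle(code))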
diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py
index ac032ac..dc2baef 100644
--- a/src/ptychi/reconstructors/pie.py
+++ b/src/ptychi/reconstructors/pie.py
@@ -135,50 +135,89 @@ def compute_updates(
         delta_p_i = None
         if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)):
+
             if (self.parameter_group.probe.representation == "sparse_code"):
-                # TODO: move this into SynthesisDictLearnProbe class
+
+                # TODO: move these into SynthesisDictLearnProbe class
                 rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2]
                 n_scpm = delta_exwv_i.shape[-3]
                 n_spos = delta_exwv_i.shape[-4]
 
-                obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc ))
-                abs2_obj_patches = torch.abs(obj_patches_vec) ** 2
+                obj_patches_conj = torch.conj(obj_patches[:, i_slice, ...])
+                conjT_i_delta_exwv_i = obj_patches_conj[:, None, ...] * delta_exwv_i
+
+                # Undo subpixel shifts and reshape.
+                conjT_delta_exwv = self.adjoint_shift_probe_update_direction(indices, conjT_i_delta_exwv_i, first_mode_only=True)
+                conjT_delta_exwv_vec = torch.reshape(conjT_delta_exwv.permute(2, 3, 0, 1), (rc, n_spos, n_scpm))
 
-                z = torch.sum(abs2_obj_patches, dim = 0)
-                z_max = torch.max(z)
-                w = self.parameter_group.probe.options.alpha * (z_max - z)
-                z_plus_w = torch.swapaxes(z + w, 0, 1)
+                obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, rc))
+                abs2_obj_patches = torch.abs(obj_patches_vec) ** 2
+
+                z_plus_w = torch.max(abs2_obj_patches, dim=0, keepdim=True)[0]
+                z_plus_w = self.parameter_group.probe.options.alpha * (z_plus_w - abs2_obj_patches)
+                z_plus_w = abs2_obj_patches + z_plus_w
+
+                # Use the average over scan positions.
+                if self.parameter_group.probe.use_avg_spos_sparse_code:
+                    z_plus_w = torch.sum(z_plus_w, 0)[None, :]
+                    conjT_delta_exwv_vec = torch.sum(conjT_delta_exwv_vec, 1)[:, None, :]
 
-                delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True)
-                delta_exwv = torch.sum(delta_exwv, 0)
-                delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T
-
-                denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix))
-                numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv
-
-                delta_sparse_code = torch.linalg.solve(denom, numer)
+                denom = torch.einsum('ij,ik,lki->jkl',
+                    self.parameter_group.probe.dictionary_matrix.conj(),
+                    self.parameter_group.probe.dictionary_matrix,
+                    z_plus_w[:, None, ...].to(torch.complex64))
+
+                numer = torch.einsum('ij,jlk->ilk',
+                    self.parameter_group.probe.dictionary_matrix.H,
+                    conjT_delta_exwv_vec)
+
+                delta_sparse_code = torch.linalg.solve(denom.permute(2, 0, 1), numer.permute(1, 0, 2))
+
+                # If the dictionary has a bad condition number, use a regularized least-squares solve instead?
+                # delta_sparse_code, _, _, _ = torch.linalg.lstsq(denom.permute(2, 0, 1), numer.permute(1, 0, 2), rcond=1e-6)
+                # delta_sparse_code = delta_sparse_code.permute(1, 0, 2)
 
-                delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code
-                delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2]))
-                delta_p_i = torch.tile(delta_p, (n_spos,1,1,1))
-
-                # sparse code update
-                sparse_code = self.parameter_group.probe.get_sparse_code_weights()
-                sparse_code = sparse_code + delta_sparse_code
+                delta_sparse_code_mean_spos = (delta_sparse_code.mean(0).T)[None, ...]
+
+                sparse_code = self.parameter_group.probe.get_sparse_code_probe_shared_weights()
+                sparse_code = sparse_code + delta_sparse_code_mean_spos
 
                 # Enforce sparsity constraint on sparse code
                 abs_sparse_code = torch.abs(sparse_code)
-                sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)
+                sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True)
 
-                sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :]
+                sel = sparse_code_sorted[0][..., self.parameter_group.probe.sparse_code_probe_nnz]
+
+                # (TODO: soft thresholding option as default?)
                 # hard thresholding:
-                sparse_code = sparse_code * (abs_sparse_code >= sel)
-
-                #(TODO: soft thresholding option)
-
+                sparse_code = sparse_code * (abs_sparse_code >= sel[..., None])
+
                 # Update the new sparse code in the probe class
-                self.parameter_group.probe.set_sparse_code(sparse_code)
+                self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code)
+
+                # Create the probe update vs scan position using the sparse code.
+                delta_p_i = torch.einsum('ij,ljk->ilk',
+                    self.parameter_group.probe.dictionary_matrix,
+                    delta_sparse_code)
+                delta_p_i = delta_p_i.permute(1, 2, 0)
+
+                if self.parameter_group.probe.use_avg_spos_sparse_code:
+                    delta_p_i = torch.tile(delta_p_i, (n_spos, 1, 1))
+
+                delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm,
+                    delta_exwv_i.shape[-2],
+                    delta_exwv_i.shape[-1]))
+
             else:
                 step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...]))
                 delta_p_i = step_weight * delta_exwv_i  # get delta p at each position
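For completeness, one way to build the dictionary_matrix / dictionary_matrix_pinv pair that SynthesisDictLearnProbeOptions expects is an orthonormal analytic basis such as a 2D DCT. This is only an illustration of the expected shape, an (h * w, n_atoms) matrix of vectorized basis functions, not a choice prescribed by this PR.

    import numpy as np
    from scipy.fft import idctn

    h = w = 16                 # small for illustration
    n_atoms = h * w
    atoms = np.zeros((h * w, n_atoms))
    for k in range(n_atoms):
        coeff = np.zeros((h, w))
        coeff[np.unravel_index(k, (h, w))] = 1.0  # one DCT coefficient at a time
        atoms[:, k] = idctn(coeff, norm="ortho").ravel()

    dictionary_matrix = atoms.astype(np.complex64)
    # For an orthonormal dictionary, the pseudoinverse is the conjugate transpose.
    dictionary_matrix_pinv = dictionary_matrix.conj().T

These arrays can be passed directly into the options shown after the base.py hunk above.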