From 3c1aeb7486677fe0676b80b41af5b0e3461c9bbf Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Thu, 31 Jul 2025 13:37:37 -0500 Subject: [PATCH 1/9] infrastructure for LSQML + sDL --- src/ptychi/api/options/lsqml.py | 9 +++++++-- src/ptychi/api/task.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/ptychi/api/options/lsqml.py b/src/ptychi/api/options/lsqml.py index 2a2f063..31c6364 100644 --- a/src/ptychi/api/options/lsqml.py +++ b/src/ptychi/api/options/lsqml.py @@ -100,6 +100,10 @@ class LSQMLObjectOptions(base.ObjectOptions): propagation always uses all probe modes regardless of this option. """ +@dataclasses.dataclass +class LSQMLProbeExperimentalOptions(base.Options): + sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(default_factory=base.SynthesisDictLearnProbeOptions) + @dataclasses.dataclass class LSQMLProbeOptions(base.ProbeOptions): @@ -107,8 +111,9 @@ class LSQMLProbeOptions(base.ProbeOptions): """ A scaler for the solved optimal step size (beta_LSQ in PtychoShelves). """ - - + experimental: LSQMLProbeExperimentalOptions = dataclasses.field(default_factory=LSQMLProbeExperimentalOptions) + + @dataclasses.dataclass class LSQMLProbePositionOptions(base.ProbePositionOptions): pass diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index 86f76a3..1807b7b 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -169,7 +169,9 @@ def build_probe(self): ): self.probe = probe.DIPProbe(**kwargs) elif ( - isinstance(self.probe_options, api.options.PIEProbeOptions) + isinstance(self.probe_options, api.options.PIEProbeOptions) + or + isinstance(self.probe_options, api.options.LSQMLProbeOptions) ) and ( self.probe_options.experimental.sdl_probe_options.enabled ): From cb2f64153ef9c632852748161cc7d8c72e320a24 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Thu, 31 Jul 2025 15:42:30 -0500 Subject: [PATCH 2/9] calculation of sparse code update direction, next step is computing optimal step for sparse code update using uncoupled object and probe step calculation. after that, do the coupled step length calc --- src/ptychi/reconstructors/lsqml.py | 44 +++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 412c8d8..e9f0aff 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -225,13 +225,43 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) ) ) - # Calculate probe update direction. - delta_p_i_unshifted = self._calculate_probe_update_direction( - chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None - ) # Eq. 24a - delta_p_i = self.adjoint_shift_probe_update_direction( - indices, delta_p_i_unshifted, first_mode_only=True - ) + if (self.parameter_group.probe.representation == "sparse_code"): + + rc = chi.shape[-1] * chi.shape[-2] + n_scpm = chi.shape[-3] + n_spos = chi.shape[-4] + + obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc )) + + obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) + delta_sparse_code = chi * torch.conj( obj_patches_slice_i_conj[:, None, ... ] ) + delta_sparse_code = torch.reshape(delta_sparse_code, (n_spos, n_scpm, rc )) + delta_sparse_code = torch.swapaxes(delta_sparse_code, 0, -1 ) + delta_sparse_code = torch.swapaxes(delta_sparse_code, -2, -1 ) + + # Use einsum for efficient batch matrix multiplication + # delta_sparse_code: (256², 926, 5) + # dictionary_matrix_H.T: (256², 200) + # Result: (200, 926, 5) + delta_sparse_code = torch.einsum('ijk,il->ljk', delta_sparse_code, self.parameter_group.probe.dictionary_matrix_H.T) + + + + chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True) + + # delta_sparse_code = self.parameter_group.probe.dictionary_matrix_H + # #delta_sparse_code = delta_sparse_code * + + else: + # Calculate probe update direction. + delta_p_i_unshifted = self._calculate_probe_update_direction( + chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None + ) # Eq. 24a + + delta_p_i = self.adjoint_shift_probe_update_direction( + indices, delta_p_i_unshifted, first_mode_only=True + ) + delta_p_hat = self._precondition_probe_update_direction(delta_p_i) # Eq. 25a # Update OPR modes and weights. From ec900f77127f06d783c0d5b4d25c24699727dd6f Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Tue, 5 Aug 2025 14:10:34 -0500 Subject: [PATCH 3/9] independent optimal step length calc for LSQML sample and sparse code; changes to rPIE and synthesissparseprobe class with cleaner tensor products using torch.einsum --- src/ptychi/api/options/base.py | 5 ++ src/ptychi/data_structures/probe.py | 18 +++--- src/ptychi/reconstructors/lsqml.py | 92 +++++++++++++++++++++++++---- src/ptychi/reconstructors/pie.py | 89 ++++++++++++++++++++-------- 4 files changed, 157 insertions(+), 47 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 66674bc..f1ae676 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -649,6 +649,11 @@ def get_non_data_fields(self) -> dict: @dataclasses.dataclass class SynthesisDictLearnProbeOptions(Options): + use_avg_spos_sparse_code: bool = True + """When computing the sparse code updates, we can either solve for + sparse codes that are scan position dependent or we can use the average + over scan positions before solving for the average sparse code.""" + d_mat: Union[ndarray, Tensor] = None """The synthesis sparse dictionary matrix; contains the basis functions that will be used to represent the probe via the sparse code weights.""" diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 6710825..a06ecba 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -473,6 +473,9 @@ def __init__(self, name = "probe", options = None, *args, **kwargs): sparse_code_probe = self.get_sparse_code_weights() self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe)) + use_avg_spos_sparse_code = self.options.experimental.sdl_probe_options.use_avg_spos_sparse_code + self.register_buffer("use_avg_spos_sparse_code", torch.tensor(use_avg_spos_sparse_code, dtype=torch.bool)) + self.build_optimizer() def get_dictionary(self): @@ -482,10 +485,11 @@ def get_dictionary(self): return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H def get_sparse_code_weights(self): + sz = self.data.shape - probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3])) - probe_vec = torch.swapaxes( probe_vec, 0, -1) - sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec + probe_vec = torch.reshape(self.data, (sz[0], sz[1], sz[2] * sz[3])) + sparse_code_probe = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + return sparse_code_probe def generate(self): @@ -497,13 +501,9 @@ def generate(self): Tensor A (n_opr_modes, n_modes, h, w) tensor giving the generated probe. """ - probe_vec = self.dictionary_matrix @ self.sparse_code_probe - probe_vec = torch.swapaxes( probe_vec, 0, -1) - probe = torch.reshape(probe_vec, *[self.data[0,...].shape]) - probe = probe[None,...] - # we only use sparse codes for the shared modes, not the OPRs - probe = torch.cat((probe, self.data[1:,...]), 0) + probe = torch.einsum('ij,klj->kli', self.dictionary_matrix, self.sparse_code_probe) + probe = torch.reshape( probe, *[self.data.shape] ) self.set_data(probe) return probe diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index e9f0aff..bd59570 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -230,27 +230,93 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) rc = chi.shape[-1] * chi.shape[-2] n_scpm = chi.shape[-3] n_spos = chi.shape[-4] - - obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc )) + + #====================================================================== + # sparse code update directions vs scan position and shared probe modes obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) - delta_sparse_code = chi * torch.conj( obj_patches_slice_i_conj[:, None, ... ] ) - delta_sparse_code = torch.reshape(delta_sparse_code, (n_spos, n_scpm, rc )) - delta_sparse_code = torch.swapaxes(delta_sparse_code, 0, -1 ) - delta_sparse_code = torch.swapaxes(delta_sparse_code, -2, -1 ) + + delta_sparse_code = chi * obj_patches_slice_i_conj[:, None, ... ] + delta_sparse_code = self.adjoint_shift_probe_update_direction(indices, delta_sparse_code, first_mode_only=True) + + delta_sparse_code = torch.reshape( delta_sparse_code, + ( n_spos, n_scpm, rc )) - # Use einsum for efficient batch matrix multiplication - # delta_sparse_code: (256², 926, 5) - # dictionary_matrix_H.T: (256², 200) - # Result: (200, 926, 5) - delta_sparse_code = torch.einsum('ijk,il->ljk', delta_sparse_code, self.parameter_group.probe.dictionary_matrix_H.T) + delta_sparse_code = torch.einsum('ijk,kl->lij', + delta_sparse_code, + self.parameter_group.probe.dictionary_matrix_H.T) + + #=================================================== + # compute optimal step length for sparse code update + + dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', + self.parameter_group.probe.dictionary_matrix, + delta_sparse_code) + obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc )) + + denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] + denom = torch.einsum('ij,jik->ik', + torch.conj( obj_patches_vec ), + denom) + + #===== chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True) + + numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, + ( n_spos, n_scpm, rc )).permute(2,0,1) + numer = torch.einsum('ij,jik->ik', + torch.conj( obj_patches_vec ), + numer) + + # real is used to throw away small imag part due to numerical precision errors + optimal_step_sparse_code = ( numer / denom ).real + + #===== + + optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code + + optimal_delta_sparse_code_mean_spos = ( optimal_delta_sparse_code.mean(1).T )[None, ...] + + # sparse code update + sparse_code = self.parameter_group.probe.get_sparse_code_weights() + sparse_code = sparse_code + optimal_delta_sparse_code_mean_spos + + #=========================================== + # Enforce sparsity constraint on sparse code + + abs_sparse_code = torch.abs(sparse_code) + sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) + + sel = sparse_code_sorted[0][...,self.parameter_group.probe.probe_sparse_code_nnz] + + # hard thresholding: + sparse_code = sparse_code * (abs_sparse_code >= sel[...,None]) + + #(TODO: soft thresholding option) + + #============================================== + # Update the new sparse code in the probe class + + self.parameter_group.probe.set_sparse_code(sparse_code) + + #=============================================================== + # Create the probe update vs scan position using the sparse code + + delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, + optimal_delta_sparse_code) + delta_p_i = delta_p_i.permute(1,2,0) + + # if self.parameter_group.probe.use_avg_spos_sparse_code: + # delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) ) + + delta_p_i = torch.reshape(delta_p_i, ( n_spos, n_scpm, + chi.shape[-1], + chi.shape[-2] )) - # delta_sparse_code = self.parameter_group.probe.dictionary_matrix_H - # #delta_sparse_code = delta_sparse_code * + delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) else: # Calculate probe update direction. diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py index ac032ac..dc2baef 100644 --- a/src/ptychi/reconstructors/pie.py +++ b/src/ptychi/reconstructors/pie.py @@ -135,50 +135,89 @@ def compute_updates( delta_p_i = None if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)): + if (self.parameter_group.probe.representation == "sparse_code"): - # TODO: move this into SynthesisDictLearnProbe class + + # TODO: move these into SynthesisDictLearnProbe class rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2] n_scpm = delta_exwv_i.shape[-3] n_spos = delta_exwv_i.shape[-4] - obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc )) - abs2_obj_patches = torch.abs(obj_patches_vec) ** 2 + obj_patches_conj = torch.conj( obj_patches[:, i_slice, ...]) + conjT_i_delta_exwv_i = obj_patches_conj[:, None,...] * delta_exwv_i + + # undo subpixel shifts and reshape + conjT_delta_exwv = self.adjoint_shift_probe_update_direction(indices, conjT_i_delta_exwv_i, first_mode_only=True) + conjT_delta_exwv_vec = torch.reshape( conjT_delta_exwv.permute( 2, 3, 0, 1 ), (rc, n_spos, n_scpm ) ) - z = torch.sum(abs2_obj_patches, dim = 0) - z_max = torch.max(z) - w = self.parameter_group.probe.options.alpha * (z_max - z) - z_plus_w = torch.swapaxes(z + w, 0, 1) + obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc )) + abs2_obj_patches = torch.abs( obj_patches_vec )**2 + + z_plus_w = torch.max(abs2_obj_patches, dim=0, keepdim=True)[0] + z_plus_w = self.parameter_group.probe.options.alpha * (z_plus_w - abs2_obj_patches) + z_plus_w = abs2_obj_patches + z_plus_w + + #================================ + # use average over scan positions - delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True) - delta_exwv = torch.sum(delta_exwv, 0) - delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T + if self.parameter_group.probe.use_avg_spos_sparse_code: + + z_plus_w = torch.sum( z_plus_w, 0 )[None,:] + conjT_delta_exwv_vec = torch.sum( conjT_delta_exwv_vec, 1 )[:,None,:] - denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix)) - numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv + #===== + + denom = torch.einsum('ij,ik,lki->jkl', + self.parameter_group.probe.dictionary_matrix.conj(), + self.parameter_group.probe.dictionary_matrix, + z_plus_w[:,None,...].to(torch.complex64)) + + numer = torch.einsum('ij,jlk->ilk', + self.parameter_group.probe.dictionary_matrix_H, + conjT_delta_exwv_vec) + + delta_sparse_code = torch.linalg.solve(denom.permute(2, 0, 1), numer.permute(1, 0, 2)) - delta_sparse_code = torch.linalg.solve(denom, numer) + # # If dictionary has bad condition number, use Tikhonov regularization? + # delta_sparse_code, _, _, _ = torch.linalg.lstsq(denom.permute(2, 0, 1), numer.permute(1, 0, 2), rcond=1e-6) + # delta_sparse_code = delta_sparse_code.permute(1, 0, 2) - delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code - delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2])) - delta_p_i = torch.tile(delta_p, (n_spos,1,1,1)) - - # sparse code update + delta_sparse_code_mean_spos = ( delta_sparse_code.mean(0).T )[None, ...] + sparse_code = self.parameter_group.probe.get_sparse_code_weights() - sparse_code = sparse_code + delta_sparse_code + sparse_code = sparse_code + delta_sparse_code_mean_spos + #=========================================== # Enforce sparsity constraint on sparse code + abs_sparse_code = torch.abs(sparse_code) - sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) - sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :] + sel = sparse_code_sorted[0][..., self.parameter_group.probe.probe_sparse_code_nnz] + #(TODO: soft thresholding option as default?) # hard thresholding: - sparse_code = sparse_code * (abs_sparse_code >= sel) - - #(TODO: soft thresholding option) - + sparse_code = sparse_code * (abs_sparse_code >= sel[...,None]) + + #============================================== # Update the new sparse code in the probe class + self.parameter_group.probe.set_sparse_code(sparse_code) + + #=============================================================== + # Create the probe update vs scan position using the sparse code + + delta_p_i = torch.einsum('ij,ljk->ilk', self.parameter_group.probe.dictionary_matrix, + delta_sparse_code) + delta_p_i = delta_p_i.permute(1,2,0) + + if self.parameter_group.probe.use_avg_spos_sparse_code: + delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) ) + + delta_p_i = torch.reshape(delta_p_i, ( n_spos, n_scpm, + delta_exwv_i.shape[-1], + delta_exwv_i.shape[-2] )) + else: step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...])) delta_p_i = step_weight * delta_exwv_i # get delta p at each position From f0ed9dd8b85a6e88eb062be597d9a2c5bcec2c51 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 6 Aug 2025 11:11:56 -0500 Subject: [PATCH 4/9] lsqml correct step length --- src/ptychi/reconstructors/lsqml.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index bd59570..fb12a30 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -637,6 +637,34 @@ def calculate_object_and_probe_update_step_sizes( alpha_vec = torch.linalg.solve(a_mat, b_vec) alpha_vec = alpha_vec.real.clip(0, None) + + + + ''' + + # NOTE: FoldSlice does not use real() for off-diagonal terms, and it only computes this for the dominant mode (we use all incoherent modes here) + # sum over r (spatial pixel index), p (incoherent mode index) + a11 = torch.sum(delta_o_patches_p.abs() ** 2, dim=(-1, -2, -3)) + 1e-6 + a11 += 0.5 * torch.mean(a11, dim=0) + a22 = torch.sum(delta_p_o.abs() ** 2, dim=(-1, -2, -3)) + 1e-6 + a22 += 0.5 * torch.mean(a22, dim=0) + a12 = torch.sum(torch.real(delta_o_patches_p * delta_p_o.conj()), dim=(-1, -2, -3)) + a21 = a12 + b1 = torch.sum(torch.real(delta_o_patches_p.conj() * chi), dim=(-1, -2, -3)) + b2 = torch.sum(torch.real(delta_p_o.conj() * chi), dim=(-1, -2, -3)) + + a_mat = torch.stack([a11, a12, a21, a22], dim=1).view(-1, 2, 2) + b_vec = torch.stack([b1, b2], dim=1).view(-1, 2).type(a_mat.dtype) + alpha_vec2 = torch.linalg.solve(a_mat, b_vec) + + # alpha_vec = torch.where(alpha_vec < 0, 0, alpha_vec) + + ''' + + + + + alpha_o_i = alpha_vec[:, 0] alpha_p_i = alpha_vec[:, 1] From 519b961afe9af31000d8d34244f8e497de3b1e76 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 13 Aug 2025 16:28:42 -0500 Subject: [PATCH 5/9] added dictionary learning to the OPR mode updates, I still need to define a separate representation = 'sparse_code' to the OPR options since I'm using the shared probe update representation currently. --- .../data_structures/opr_mode_weights.py | 73 ++++++++++++++++++- src/ptychi/reconstructors/lsqml.py | 32 +------- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/ptychi/data_structures/opr_mode_weights.py b/src/ptychi/data_structures/opr_mode_weights.py index 7c0c032..ef3de08 100644 --- a/src/ptychi/data_structures/opr_mode_weights.py +++ b/src/ptychi/data_structures/opr_mode_weights.py @@ -101,6 +101,7 @@ def update_variable_probe( chi: Tensor, delta_p_i: Tensor, delta_p_hat: Tensor, + probe_current_slice: Tensor, obj_patches: Tensor, current_epoch: int, probe_mode_index: Optional[int] = None, @@ -118,7 +119,7 @@ def update_variable_probe( or (self.eigenmode_weight_optimization_enabled(current_epoch)) ): self.update_opr_probe_modes_and_weights( - probe, indices, chi, delta_p_i, delta_p_hat, obj_patches, current_epoch + probe, indices, chi, delta_p_i, delta_p_hat, probe_current_slice, obj_patches, current_epoch ) if self.intensity_variation_optimization_enabled(current_epoch): @@ -138,6 +139,7 @@ def update_opr_probe_modes_and_weights( chi: Tensor, delta_p_i: Tensor, delta_p_hat: Tensor, + probe_current_slice: Tensor, obj_patches: Tensor, current_epoch: int, ): @@ -158,6 +160,69 @@ def update_opr_probe_modes_and_weights( if batch_size == 1: return + update_eigenmode = probe.optimization_enabled(current_epoch) + + # TODO: need to introduce a separate OPR mode representation + # parameter; this just uses the probe representation parameter. + # This is for if we want to e.g. use dense representation for + # the shared probe modes and sparse representation for the OPRs. + + if (probe.representation == "sparse_code") and update_eigenmode: + + sz = delta_p_i.shape + rc = sz[-2] * sz[-1] + Nspos = sz[0] + + weights_ALL = self.get_weights(indices)[:,1:] + + probe_current_slice_vec = torch.reshape( probe_current_slice[:,0,...], (Nspos, rc) ).T + + #================================================================================== + # need to use regularization here because weights_ALL has terrible condition number + + sparse_code_opr = probe.dictionary_matrix_pinv @ probe_current_slice_vec + + A = weights_ALL.to(torch.complex64) + b = sparse_code_opr.T + + lambda_value = 1e-2 # Tikhonov regularization parameter + L = torch.eye(A.shape[1]) # L matrix (identity for standard Tikhonov) + + A_reg = torch.cat((A, lambda_value * L), dim=0) + b_reg = torch.cat((b, torch.zeros(L.shape[0], b.shape[1])), dim=0) + + sparse_code_opr = torch.linalg.lstsq(A_reg, b_reg).solution.T + + #print( torch.linalg.norm( sparse_code_opr, dim=0 ) ) + + #============================ + # enforce sparsity constraint + + abs_sparse_code_opr = torch.abs(sparse_code_opr) + abs_sparse_code_opr_sorted = torch.sort(abs_sparse_code_opr, dim=0, descending=True) + + sparsity_nnz = 50 + + sel = abs_sparse_code_opr_sorted[0][sparsity_nnz,:] + + sparse_code_opr = sparse_code_opr * (abs_sparse_code_opr >= sel) + + eigenmodes_updated = probe.dictionary_matrix @ sparse_code_opr + + eigenmodes_updated = torch.reshape( eigenmodes_updated.T, ( eigenmodes_updated.shape[-1], + sz[-2], sz[-1] )) + + # scaling_ratio = torch.linalg.norm( probe.data[1:,0,...], dim=(-1,-2)) / torch.linalg.norm( eigenmodes_updated, dim=(-1,-2)) + + # print( torch.linalg.norm( eigenmodes_updated, dim=(-1,-2))) + + # print( torch.linalg.norm( probe.data[1:,0,...], dim=(-1,-2))) + + probe_data[1:, 0, :, :] = eigenmodes_updated + probe.set_data(probe_data) + + update_eigenmode = False + # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation relax_v = self.options.update_relaxation @@ -180,7 +245,7 @@ def update_opr_probe_modes_and_weights( relax_v, obj_patches, chi, - update_eigenmode=probe.optimization_enabled(current_epoch), + update_eigenmode = update_eigenmode, update_weights=self.eigenmode_weight_optimization_enabled(current_epoch), ) @@ -193,8 +258,8 @@ def update_opr_probe_modes_and_weights( probe_data[i_opr_mode, 0, :, :] = eigenmode_i weights_data[indices, i_opr_mode] = weights_i - if probe.optimization_enabled(current_epoch): - probe.set_data(probe_data) + if probe.optimization_enabled(current_epoch): # and not (probe.representation == "sparse_code") + probe.set_data(probe_data) # do we need to do this again if using sparse code? if self.eigenmode_weight_optimization_enabled(current_epoch): self.set_data(weights_data) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index fb12a30..6e0975b 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -309,9 +309,6 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) optimal_delta_sparse_code) delta_p_i = delta_p_i.permute(1,2,0) - # if self.parameter_group.probe.use_avg_spos_sparse_code: - # delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) ) - delta_p_i = torch.reshape(delta_p_i, ( n_spos, n_scpm, chi.shape[-1], chi.shape[-2] )) @@ -339,6 +336,7 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) chi, delta_p_i, delta_p_hat, + probe_current_slice, obj_patches, self.current_epoch, probe_mode_index=0, @@ -637,34 +635,6 @@ def calculate_object_and_probe_update_step_sizes( alpha_vec = torch.linalg.solve(a_mat, b_vec) alpha_vec = alpha_vec.real.clip(0, None) - - - - ''' - - # NOTE: FoldSlice does not use real() for off-diagonal terms, and it only computes this for the dominant mode (we use all incoherent modes here) - # sum over r (spatial pixel index), p (incoherent mode index) - a11 = torch.sum(delta_o_patches_p.abs() ** 2, dim=(-1, -2, -3)) + 1e-6 - a11 += 0.5 * torch.mean(a11, dim=0) - a22 = torch.sum(delta_p_o.abs() ** 2, dim=(-1, -2, -3)) + 1e-6 - a22 += 0.5 * torch.mean(a22, dim=0) - a12 = torch.sum(torch.real(delta_o_patches_p * delta_p_o.conj()), dim=(-1, -2, -3)) - a21 = a12 - b1 = torch.sum(torch.real(delta_o_patches_p.conj() * chi), dim=(-1, -2, -3)) - b2 = torch.sum(torch.real(delta_p_o.conj() * chi), dim=(-1, -2, -3)) - - a_mat = torch.stack([a11, a12, a21, a22], dim=1).view(-1, 2, 2) - b_vec = torch.stack([b1, b2], dim=1).view(-1, 2).type(a_mat.dtype) - alpha_vec2 = torch.linalg.solve(a_mat, b_vec) - - # alpha_vec = torch.where(alpha_vec < 0, 0, alpha_vec) - - ''' - - - - - alpha_o_i = alpha_vec[:, 0] alpha_p_i = alpha_vec[:, 1] From 91773fe9d90d32ba361098de20ea7303228f6848 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Tue, 19 Aug 2025 17:02:38 -0500 Subject: [PATCH 6/9] Work in progress reflecting some of the changes Ming requested, still implementing the other suggestions from Ming --- src/ptychi/api/options/base.py | 30 ++++-- .../data_structures/opr_mode_weights.py | 102 +++++++++++------- src/ptychi/data_structures/probe.py | 77 +++++++++---- src/ptychi/maths.py | 4 + src/ptychi/reconstructors/lsqml.py | 16 +-- 5 files changed, 152 insertions(+), 77 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index f1ae676..9fc8f87 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -654,24 +654,38 @@ class SynthesisDictLearnProbeOptions(Options): sparse codes that are scan position dependent or we can use the average over scan positions before solving for the average sparse code.""" - d_mat: Union[ndarray, Tensor] = None + dictionary_matrix: Union[ndarray, Tensor] = None """The synthesis sparse dictionary matrix; contains the basis functions that will be used to represent the probe via the sparse code weights.""" - d_mat_conj_transpose: Union[ndarray, Tensor] = None - """Conjugate transpose of the synthesis sparse dictionary matrix.""" - - d_mat_pinv: Union[ndarray, Tensor] = None + dictionary_matrix_pinv: Union[ndarray, Tensor] = None """Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix.""" - probe_sparse_code: Union[ndarray, Tensor] = None + sparse_code_probe_shared: Union[ndarray, Tensor] = None """Sparse code weights vector.""" - probe_sparse_code_nnz: float = None + sparse_code_probe_shared_nnz: float = None + """Number of non-zeros we will keep when enforcing sparsity constraint on + the sparse code weights vector sparse_code_probe.""" + + sparse_code_probe_opr: Union[ndarray, Tensor] = None + """Sparse code weights vector for the OPRs.""" + + sparse_code_probe_opr_nnz: float = None """Number of non-zeros we will keep when enforcing sparsity constraint on - the sparse code weights vector probe_sparse_code.""" + the sparse code weights vector sparse_code_opr.""" + + sparse_code_probe_shared_start: int = torch.inf + sparse_code_probe_shared_stop: int = torch.inf + sparse_code_probe_shared_stride: int = 1 + + sparse_code_probe_opr_start: int = torch.inf + sparse_code_probe_opr_stop: int = torch.inf + sparse_code_probe_opr_stride: int = 1 enabled: bool = False + enabled_shared: bool = False + enabled_opr: bool = False @dataclasses.dataclass class PositionCorrectionOptions(Options): diff --git a/src/ptychi/data_structures/opr_mode_weights.py b/src/ptychi/data_structures/opr_mode_weights.py index ef3de08..987e4e6 100644 --- a/src/ptychi/data_structures/opr_mode_weights.py +++ b/src/ptychi/data_structures/opr_mode_weights.py @@ -151,7 +151,9 @@ def update_opr_probe_modes_and_weights( """ probe_data = probe.data weights_data = self.data - + + # print( torch.linalg.norm(probe_data,dim=(-1,-2)), torch.linalg.norm( weights_data,dim=(0)) ) + batch_size = len(delta_p_i) n_points_total = self.n_scan_points @@ -161,66 +163,84 @@ def update_opr_probe_modes_and_weights( return update_eigenmode = probe.optimization_enabled(current_epoch) - - # TODO: need to introduce a separate OPR mode representation - # parameter; this just uses the probe representation parameter. - # This is for if we want to e.g. use dense representation for - # the shared probe modes and sparse representation for the OPRs. - - if (probe.representation == "sparse_code") and update_eigenmode: + + if (update_eigenmode + and probe.representation == "sparse_code" + and probe.options.experimental.sdl_probe_options.enabled_opr + and probe.options.experimental.sdl_probe_options.sparse_code_probe_opr_start >= current_epoch + and probe.options.experimental.sdl_probe_options.sparse_code_probe_opr_stop <= current_epoch + and (current_epoch % probe.options.experimental.sdl_probe_options.sparse_code_probe_opr_stride) == 0 + ): sz = delta_p_i.shape - rc = sz[-2] * sz[-1] - Nspos = sz[0] - weights_ALL = self.get_weights(indices)[:,1:] - probe_current_slice_vec = torch.reshape( probe_current_slice[:,0,...], (Nspos, rc) ).T - - #================================================================================== - # need to use regularization here because weights_ALL has terrible condition number + #print( torch.linalg.cond( weights_ALL ) ) + probe_current_slice_vec = torch.reshape( probe_current_slice[:,0,...], (sz[0], sz[-2] * sz[-1]) ).T sparse_code_opr = probe.dictionary_matrix_pinv @ probe_current_slice_vec - A = weights_ALL.to(torch.complex64) - b = sparse_code_opr.T + # need to use regularization here because weights_ALL has terrible condition number + if torch.linalg.cond( weights_ALL ) > 5: - lambda_value = 1e-2 # Tikhonov regularization parameter - L = torch.eye(A.shape[1]) # L matrix (identity for standard Tikhonov) + A = weights_ALL.to(torch.complex64) + b = sparse_code_opr.T + + lambda_value = 1e-3 # Tikhonov regularization parameter + L = torch.eye(A.shape[1]) # L matrix (identity for standard Tikhonov) - A_reg = torch.cat((A, lambda_value * L), dim=0) - b_reg = torch.cat((b, torch.zeros(L.shape[0], b.shape[1])), dim=0) - - sparse_code_opr = torch.linalg.lstsq(A_reg, b_reg).solution.T + A_reg = torch.cat((A, lambda_value * L), dim=0) + b_reg = torch.cat((b, torch.zeros(L.shape[0], b.shape[1])), dim=0) + sparse_code_opr = torch.linalg.lstsq(A_reg, b_reg).solution.T + else: + sparse_code_opr = sparse_code_opr @ torch.linalg.pinv( weights_ALL.to(torch.complex64) ).T + + #print( torch.linalg.cond( sparse_code_opr ) ) #print( torch.linalg.norm( sparse_code_opr, dim=0 ) ) - #============================ # enforce sparsity constraint - + + sparsity_nnz = 195 abs_sparse_code_opr = torch.abs(sparse_code_opr) - abs_sparse_code_opr_sorted = torch.sort(abs_sparse_code_opr, dim=0, descending=True) - sparsity_nnz = 50 + # abs_sparse_code_opr_sorted = torch.sort(abs_sparse_code_opr, dim=0, descending=True) + # sel = abs_sparse_code_opr_sorted[0][sparsity_nnz,:] + sel = torch.sort(abs_sparse_code_opr, dim=0, descending=True)[0][sparsity_nnz,:] + sparse_code_opr_mask = (abs_sparse_code_opr >= sel) + + sparse_code_opr = sparse_code_opr * sparse_code_opr_mask # Hard Thresholding + #sparse_code_opr = ( abs_sparse_code_opr - sel) * sparse_code_opr_mask * torch.exp(1j * torch.angle(sparse_code_opr)) # Soft Thresholding + + # SET THE NEW OPR SPARSE CODES + + probe.options.experimental.sdl_probe_options.set_sparse_code_probe_opr( sparse_code_opr ) - sel = abs_sparse_code_opr_sorted[0][sparsity_nnz,:] - sparse_code_opr = sparse_code_opr * (abs_sparse_code_opr >= sel) - + + + # Use updated OPR sparse code to generate the dense representation OPR modes eigenmodes_updated = probe.dictionary_matrix @ sparse_code_opr - eigenmodes_updated = torch.reshape( eigenmodes_updated.T, ( eigenmodes_updated.shape[-1], - sz[-2], sz[-1] )) + eigenmodes_updated = torch.reshape(eigenmodes_updated.T, + (eigenmodes_updated.shape[-1], sz[-2], sz[-1])) + + #================================================ + # rescale so that all OPR modes have correct norm - # scaling_ratio = torch.linalg.norm( probe.data[1:,0,...], dim=(-1,-2)) / torch.linalg.norm( eigenmodes_updated, dim=(-1,-2)) - - # print( torch.linalg.norm( eigenmodes_updated, dim=(-1,-2))) + eigenmodes_current_scaling = torch.linalg.norm(eigenmodes_updated, axis=(-1,-2), keepdims=True) + eigenmodes_updated_scaling = torch.sqrt(torch.prod(torch.tensor(eigenmodes_updated.shape[-2:]))) + + scaling_ratio = eigenmodes_updated_scaling / eigenmodes_current_scaling + + eigenmodes_updated = eigenmodes_updated * scaling_ratio - # print( torch.linalg.norm( probe.data[1:,0,...], dim=(-1,-2))) + #===== - probe_data[1:, 0, :, :] = eigenmodes_updated - probe.set_data(probe_data) + w = 1 - 0e-6 + probe_data[1:, 0, :, :] = w * eigenmodes_updated + ( 1 - w ) * probe_data[:, 0, :, :] + probe.set_data(probe_data) update_eigenmode = False # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. @@ -258,8 +278,10 @@ def update_opr_probe_modes_and_weights( probe_data[i_opr_mode, 0, :, :] = eigenmode_i weights_data[indices, i_opr_mode] = weights_i - if probe.optimization_enabled(current_epoch): # and not (probe.representation == "sparse_code") - probe.set_data(probe_data) # do we need to do this again if using sparse code? + # print( torch.linalg.norm(probe_data,dim=(-1,-2)), torch.linalg.norm( weights_data,dim=(0)) ) + + if update_eigenmode and not (probe.representation == "sparse_code"): + probe.set_data(probe_data) if self.eigenmode_weight_optimization_enabled(current_epoch): self.set_data(weights_data) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index a06ecba..8a93931 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -462,35 +462,59 @@ def __init__(self, name = "probe", options = None, *args, **kwargs): super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs) - dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H = self.get_dictionary() + dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary() self.register_buffer("dictionary_matrix", dictionary_matrix) self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv) - self.register_buffer("dictionary_matrix_H", dictionary_matrix_H) + + sparse_code_probe_shared_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32 ) + self.register_buffer("sparse_code_probe_nnz", sparse_code_probe_shared_nnz ) - probe_sparse_code_nnz = torch.tensor( self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.uint32 ) - self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz ) + sparse_code_probe_opr_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_nnz, dtype=torch.uint32 ) + self.register_buffer("sparse_code_opr_nnz", sparse_code_probe_opr_nnz ) - sparse_code_probe = self.get_sparse_code_weights() - self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe)) + sparse_code_probe_shared = self.get_sparse_code_weights() + self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)) + + sparse_code_probe_opr = self.get_sparse_code_weights() + self.register_parameter("sparse_code_probe_opr", torch.nn.Parameter(sparse_code_probe_opr)) use_avg_spos_sparse_code = self.options.experimental.sdl_probe_options.use_avg_spos_sparse_code self.register_buffer("use_avg_spos_sparse_code", torch.tensor(use_avg_spos_sparse_code, dtype=torch.bool)) + sparse_code_probe_shared_start = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_start ) + self.register_buffer("sparse_code_probe_shared_start", sparse_code_probe_shared_start ) + + sparse_code_probe_shared_stop = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_stop ) + self.register_buffer("sparse_code_probe_shared_stop", sparse_code_probe_shared_stop ) + + sparse_code_probe_shared_stride = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_stride ) + self.register_buffer("sparse_code_probe_shared_stride", sparse_code_probe_shared_stride ) + + sparse_code_probe_opr_start = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_start ) + self.register_buffer("sparse_code_probe_opr_start", sparse_code_probe_opr_start ) + + sparse_code_probe_opr_stop = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_stop ) + self.register_buffer("sparse_code_probe_opr_stop", sparse_code_probe_opr_stop ) + + sparse_code_probe_opr_stride = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_stride ) + self.register_buffer("sparse_code_probe_opr_stride", sparse_code_probe_opr_stride ) + self.build_optimizer() def get_dictionary(self): - dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.d_mat, dtype=torch.complex64 ) - dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_pinv, dtype=torch.complex64 ) - dictionary_matrix_H = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_conj_transpose, dtype=torch.complex64 ) - return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H + + dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64 ) + dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix_pinv, dtype=torch.complex64 ) + + return dictionary_matrix, dictionary_matrix_pinv def get_sparse_code_weights(self): sz = self.data.shape probe_vec = torch.reshape(self.data, (sz[0], sz[1], sz[2] * sz[3])) - sparse_code_probe = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + sparse_code = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) - return sparse_code_probe + return sparse_code def generate(self): """Generate the probe using the sparse code, and set the @@ -502,10 +526,18 @@ def generate(self): A (n_opr_modes, n_modes, h, w) tensor giving the generated probe. """ - probe = torch.einsum('ij,klj->kli', self.dictionary_matrix, self.sparse_code_probe) - probe = torch.reshape( probe, *[self.data.shape] ) - - self.set_data(probe) + if (self.options.experimental.sdl_probe_options.enabled_shared): + + # GENERATE OPR MODES FOR probe[1:,...] HERE USING self.sparse_code_probe_opr + + probe = torch.einsum('ij,klj->kli', self.dictionary_matrix, self.sparse_code_probe_shared) + probe = torch.reshape( probe, *[self.data.shape] ) + + self.set_data(probe) + + else: + probe = self.probe.data + return probe def build_optimizer(self): @@ -514,12 +546,15 @@ def build_optimizer(self): "Parameter {} is optimizable but no optimizer is specified.".format(self.name) ) if self.optimizable: - self.optimizer = self.optimizer_class([self.sparse_code_probe], **self.optimizer_params) - - def set_sparse_code(self, data): - self.sparse_code_probe.data = data - + self.optimizer = self.optimizer_class([self.sparse_code_probe_shared], **self.optimizer_params) + def set_sparse_code_probe_shared(self, data): + self.sparse_code_probe_shared.data = data + + def set_sparse_code_probe_opr(self, data): + self.sparse_code_probe_opr.data = data + + class DIPProbe(Probe): options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions" diff --git a/src/ptychi/maths.py b/src/ptychi/maths.py index d19efc0..292aa86 100644 --- a/src/ptychi/maths.py +++ b/src/ptychi/maths.py @@ -213,6 +213,10 @@ def orthogonalize_svd( def project(a, b, dim=None): """Return complex vector projection of a onto b for along given axis.""" projected_length = inner(a, b, dim=dim, keepdims=True) / inner(b, b, dim=dim, keepdims=True) + + # if the inner product of b with itself has any zeros: + projected_length = torch.nan_to_num(projected_length, nan=0.0) + return projected_length * b def inner(x, y, dim=None, keepdims=False): diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 6e0975b..422c64d 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -225,7 +225,7 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) ) ) - if (self.parameter_group.probe.representation == "sparse_code"): + if (self.parameter_group.probe.representation == "sparse_code") and (self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): rc = chi.shape[-1] * chi.shape[-2] n_scpm = chi.shape[-3] @@ -244,7 +244,7 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) delta_sparse_code = torch.einsum('ijk,kl->lij', delta_sparse_code, - self.parameter_group.probe.dictionary_matrix_H.T) + self.parameter_group.probe.dictionary_matrix.conj()) #=================================================== # compute optimal step length for sparse code update @@ -278,6 +278,7 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code + # TODO: At some point we'll want to use scan position dependent sparse codes instead of this averaged sparse code optimal_delta_sparse_code_mean_spos = ( optimal_delta_sparse_code.mean(1).T )[None, ...] # sparse code update @@ -286,11 +287,12 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) #=========================================== # Enforce sparsity constraint on sparse code + # !!!!! MOVE THIS TO BEFORE OPTIMAL STEP CALC !!!!! abs_sparse_code = torch.abs(sparse_code) sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) - sel = sparse_code_sorted[0][...,self.parameter_group.probe.probe_sparse_code_nnz] + sel = sparse_code_sorted[0][...,self.parameter_group.probe.sparse_code_probe_nnz] # hard thresholding: sparse_code = sparse_code * (abs_sparse_code >= sel[...,None]) @@ -300,18 +302,16 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) #============================================== # Update the new sparse code in the probe class - self.parameter_group.probe.set_sparse_code(sparse_code) + self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code) #=============================================================== # Create the probe update vs scan position using the sparse code delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, optimal_delta_sparse_code) - delta_p_i = delta_p_i.permute(1,2,0) + delta_p_i = delta_p_i.permute(1, 2, 0) - delta_p_i = torch.reshape(delta_p_i, ( n_spos, n_scpm, - chi.shape[-1], - chi.shape[-2] )) + delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) From 5979d661a71ead65f6b03d96ea76d040d0147feb Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Tue, 19 Aug 2025 17:22:04 -0500 Subject: [PATCH 7/9] style changes and more control logic tweaking --- src/ptychi/data_structures/probe.py | 18 +++++++++++ src/ptychi/reconstructors/lsqml.py | 48 ++++++++++------------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 8a93931..6e2b0de 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -516,6 +516,24 @@ def get_sparse_code_weights(self): return sparse_code + def get_sparse_code_probe_shared_weights(self): + + probe_shared = self.data[0,...] + sz = probe_shared.shape + probe_vec = torch.reshape(probe_shared, (sz[0], sz[1] * sz[2])) + sparse_code = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + + return sparse_code[None,...] + + def get_sparse_code_probe_opr_weights(self): + + probe_opr = self.data[1:,0,...] + sz = probe_opr.shape + probe_vec = torch.reshape(probe_opr, (sz[0], sz[1] * sz[2])) + sparse_code = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + + return sparse_code[:,None,...] + def generate(self): """Generate the probe using the sparse code, and set the generated probe to self.data. diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 422c64d..ac05c12 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -225,70 +225,61 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) ) ) - if (self.parameter_group.probe.representation == "sparse_code") and (self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): + if (self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared + and self.parameter_group.probe.options.experimental.sdl_probe_options.sparse_code_probe_shared_start >= self.current_epoch + and self.parameter_group.probe.options.experimental.sdl_probe_options.sparse_code_probe_shared_stop <= self.current_epoch + and (self.current_epoch % self.parameter_group.probe.options.experimental.sdl_probe_options.sparse_code_probe_shared_stride) == 0 + ): + + #if (self.parameter_group.probe.representation == "sparse_code") and (self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): rc = chi.shape[-1] * chi.shape[-2] n_scpm = chi.shape[-3] n_spos = chi.shape[-4] - - #====================================================================== + # sparse code update directions vs scan position and shared probe modes - obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) - delta_sparse_code = chi * obj_patches_slice_i_conj[:, None, ... ] delta_sparse_code = self.adjoint_shift_probe_update_direction(indices, delta_sparse_code, first_mode_only=True) - delta_sparse_code = torch.reshape( delta_sparse_code, ( n_spos, n_scpm, rc )) - delta_sparse_code = torch.einsum('ijk,kl->lij', delta_sparse_code, self.parameter_group.probe.dictionary_matrix.conj()) - - #=================================================== - # compute optimal step length for sparse code update - + + # compute optimal step length for sparse code update dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', self.parameter_group.probe.dictionary_matrix, delta_sparse_code) - obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc )) - + obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc )) denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] - denom = torch.einsum('ij,jik->ik', torch.conj( obj_patches_vec ), denom) - #===== - chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True) - numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, ( n_spos, n_scpm, rc )).permute(2,0,1) numer = torch.einsum('ij,jik->ik', torch.conj( obj_patches_vec ), numer) - + # real is used to throw away small imag part due to numerical precision errors optimal_step_sparse_code = ( numer / denom ).real - - #===== - + optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code # TODO: At some point we'll want to use scan position dependent sparse codes instead of this averaged sparse code optimal_delta_sparse_code_mean_spos = ( optimal_delta_sparse_code.mean(1).T )[None, ...] - # sparse code update + # sparse code update + sparse_code2 = self.parameter_group.probe.get_sparse_code_probe_shared_weights() sparse_code = self.parameter_group.probe.get_sparse_code_weights() sparse_code = sparse_code + optimal_delta_sparse_code_mean_spos - #=========================================== # Enforce sparsity constraint on sparse code - # !!!!! MOVE THIS TO BEFORE OPTIMAL STEP CALC !!!!! - abs_sparse_code = torch.abs(sparse_code) sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) @@ -296,23 +287,16 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) # hard thresholding: sparse_code = sparse_code * (abs_sparse_code >= sel[...,None]) - #(TODO: soft thresholding option) - #============================================== # Update the new sparse code in the probe class - self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code) - #=============================================================== # Create the probe update vs scan position using the sparse code - delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, optimal_delta_sparse_code) delta_p_i = delta_p_i.permute(1, 2, 0) - delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) - delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) else: From be8ea9a8e76b1b1e7fd33190eaf71d13e3cb1c11 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 27 Aug 2025 15:49:48 -0500 Subject: [PATCH 8/9] refactor of sparse code + dictionary learning shared and OPR mode update --- src/ptychi/api/options/base.py | 43 ++- .../data_structures/opr_mode_weights.py | 264 ++++++++++-------- src/ptychi/data_structures/probe.py | 105 ++++--- src/ptychi/reconstructors/lsqml.py | 137 +++++---- 4 files changed, 298 insertions(+), 251 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 9fc8f87..ccb652c 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -542,7 +542,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions): method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD """The method to use for incoherent_mode orthogonalization.""" - + + sort_by_occupancy: bool = False + """Keep the probes sorted so that mode with highest occupancy is the 0th shared mode""" @dataclasses.dataclass class ProbeOrthogonalizeOPRModesOptions(FeatureOptions): @@ -648,12 +650,15 @@ def get_non_data_fields(self) -> dict: @dataclasses.dataclass class SynthesisDictLearnProbeOptions(Options): - - use_avg_spos_sparse_code: bool = True - """When computing the sparse code updates, we can either solve for - sparse codes that are scan position dependent or we can use the average - over scan positions before solving for the average sparse code.""" - + + enabled: bool = False + enabled_shared: bool = False + enabled_opr: bool = False + + thresholding_type_shared: str = 'hard' + thresholding_type_opr: str = 'hard' + """Choose between 'hard' or 'soft' thresholding.""" + dictionary_matrix: Union[ndarray, Tensor] = None """The synthesis sparse dictionary matrix; contains the basis functions that will be used to represent the probe via the sparse code weights.""" @@ -662,30 +667,18 @@ class SynthesisDictLearnProbeOptions(Options): """Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix.""" sparse_code_probe_shared: Union[ndarray, Tensor] = None - """Sparse code weights vector.""" + """Sparse code weights vector for the shared modes.""" sparse_code_probe_shared_nnz: float = None """Number of non-zeros we will keep when enforcing sparsity constraint on - the sparse code weights vector sparse_code_probe.""" + the sparse code weights vector sparse_code_probe_shared.""" sparse_code_probe_opr: Union[ndarray, Tensor] = None """Sparse code weights vector for the OPRs.""" sparse_code_probe_opr_nnz: float = None """Number of non-zeros we will keep when enforcing sparsity constraint on - the sparse code weights vector sparse_code_opr.""" - - sparse_code_probe_shared_start: int = torch.inf - sparse_code_probe_shared_stop: int = torch.inf - sparse_code_probe_shared_stride: int = 1 - - sparse_code_probe_opr_start: int = torch.inf - sparse_code_probe_opr_stop: int = torch.inf - sparse_code_probe_opr_stride: int = 1 - - enabled: bool = False - enabled_shared: bool = False - enabled_opr: bool = False + the sparse code weights vector sparse_code_probe_opr.""" @dataclasses.dataclass class PositionCorrectionOptions(Options): @@ -888,6 +881,12 @@ class OPRModeWeightsOptions(ParameterOptions): A separate step size for eigenmode weight update. """ + use_optimal_update: bool = False + """ + We do not compute an optimal update step for OPR weights and eigenmodes + using the default method; this add the option to do this. + """ + def check(self, options: "task_options.PtychographyTaskOptions"): super().check(options) if self.optimizable: diff --git a/src/ptychi/data_structures/opr_mode_weights.py b/src/ptychi/data_structures/opr_mode_weights.py index 987e4e6..8d90aac 100644 --- a/src/ptychi/data_structures/opr_mode_weights.py +++ b/src/ptychi/data_structures/opr_mode_weights.py @@ -97,11 +97,11 @@ def intensity_variation_optimization_enabled(self, epoch: int): def update_variable_probe( self, probe: "Probe", + adjoint_shift_probe_update_direction, # what do I do for type hint here? indices: Tensor, chi: Tensor, delta_p_i: Tensor, delta_p_hat: Tensor, - probe_current_slice: Tensor, obj_patches: Tensor, current_epoch: int, probe_mode_index: Optional[int] = None, @@ -118,8 +118,14 @@ def update_variable_probe( probe.optimization_enabled(current_epoch) or (self.eigenmode_weight_optimization_enabled(current_epoch)) ): - self.update_opr_probe_modes_and_weights( - probe, indices, chi, delta_p_i, delta_p_hat, probe_current_slice, obj_patches, current_epoch + self.update_opr_probe_modes_and_weights(probe, + adjoint_shift_probe_update_direction, + indices, + chi, + delta_p_i, + delta_p_hat, + obj_patches, + current_epoch ) if self.intensity_variation_optimization_enabled(current_epoch): @@ -135,25 +141,23 @@ def update_variable_probe( def update_opr_probe_modes_and_weights( self, probe: "Probe", + adjoint_shift_probe_update_direction, # what do I do for type hint here? indices: Tensor, chi: Tensor, delta_p_i: Tensor, delta_p_hat: Tensor, - probe_current_slice: Tensor, obj_patches: Tensor, current_epoch: int, ): """ Update the eigenmodes of the first incoherent mode of the probe, and update the OPR mode weights. - This implementation is adapted from PtychoShelves code (update_variable_probe.m) and has some - differences from Eq. 31 of Odstrcil (2018). + The default (for self.options.use_optimal_update = False) implementation below is adapted from + PtychoShelves code (update_variable_probe.m) and has some differences from Eq. 31 of Odstrcil (2018). """ probe_data = probe.data weights_data = self.data - # print( torch.linalg.norm(probe_data,dim=(-1,-2)), torch.linalg.norm( weights_data,dim=(0)) ) - batch_size = len(delta_p_i) n_points_total = self.n_scan_points @@ -162,127 +166,165 @@ def update_opr_probe_modes_and_weights( if batch_size == 1: return - update_eigenmode = probe.optimization_enabled(current_epoch) - - if (update_eigenmode - and probe.representation == "sparse_code" - and probe.options.experimental.sdl_probe_options.enabled_opr - and probe.options.experimental.sdl_probe_options.sparse_code_probe_opr_start >= current_epoch - and probe.options.experimental.sdl_probe_options.sparse_code_probe_opr_stop <= current_epoch - and (current_epoch % probe.options.experimental.sdl_probe_options.sparse_code_probe_opr_stride) == 0 - ): - - sz = delta_p_i.shape - weights_ALL = self.get_weights(indices)[:,1:] - - #print( torch.linalg.cond( weights_ALL ) ) - - probe_current_slice_vec = torch.reshape( probe_current_slice[:,0,...], (sz[0], sz[-2] * sz[-1]) ).T - sparse_code_opr = probe.dictionary_matrix_pinv @ probe_current_slice_vec - - # need to use regularization here because weights_ALL has terrible condition number - if torch.linalg.cond( weights_ALL ) > 5: - - A = weights_ALL.to(torch.complex64) - b = sparse_code_opr.T - - lambda_value = 1e-3 # Tikhonov regularization parameter - L = torch.eye(A.shape[1]) # L matrix (identity for standard Tikhonov) - - A_reg = torch.cat((A, lambda_value * L), dim=0) - b_reg = torch.cat((b, torch.zeros(L.shape[0], b.shape[1])), dim=0) + update_eigenmode = probe.optimization_enabled(current_epoch) # why is this needed again? To even get into this function, we need this to already be true? + update_eigenmode_weights = self.eigenmode_weight_optimization_enabled(current_epoch) - sparse_code_opr = torch.linalg.lstsq(A_reg, b_reg).solution.T - else: - sparse_code_opr = sparse_code_opr @ torch.linalg.pinv( weights_ALL.to(torch.complex64) ).T + if self.options.use_optimal_update: - #print( torch.linalg.cond( sparse_code_opr ) ) - #print( torch.linalg.norm( sparse_code_opr, dim=0 ) ) + rc = obj_patches.shape[-2] * obj_patches.shape[-1] + n_spos = obj_patches.shape[0] - # enforce sparsity constraint - - sparsity_nnz = 195 - abs_sparse_code_opr = torch.abs(sparse_code_opr) + U = probe_data[1:, 0, ...] - # abs_sparse_code_opr_sorted = torch.sort(abs_sparse_code_opr, dim=0, descending=True) - # sel = abs_sparse_code_opr_sorted[0][sparsity_nnz,:] - sel = torch.sort(abs_sparse_code_opr, dim=0, descending=True)[0][sparsity_nnz,:] - sparse_code_opr_mask = (abs_sparse_code_opr >= sel) - - sparse_code_opr = sparse_code_opr * sparse_code_opr_mask # Hard Thresholding - #sparse_code_opr = ( abs_sparse_code_opr - sel) * sparse_code_opr_mask * torch.exp(1j * torch.angle(sparse_code_opr)) # Soft Thresholding + Ws = (weights_data[ indices, 1:]).to(torch.complex64) - # SET THE NEW OPR SPARSE CODES + Tsconj_chi = (obj_patches[:,0,...].conj() * chi[:,0,...]) + Tsconj_chi = adjoint_shift_probe_update_direction( indices, Tsconj_chi[:,None,...], first_mode_only=True) - probe.options.experimental.sdl_probe_options.set_sparse_code_probe_opr( sparse_code_opr ) + chi = adjoint_shift_probe_update_direction( indices, chi, first_mode_only=True) + U = torch.reshape(U, (U.shape[0], rc)) + chi_vec = torch.reshape(chi[:,0,...], (n_spos, rc)) + Ts = torch.reshape(obj_patches[:,0,...], (n_spos, rc)) + Tsconj_chi = torch.reshape(Tsconj_chi[:,0,...], (n_spos, rc)).T + # Optimal OPR weight updates - - # Use updated OPR sparse code to generate the dense representation OPR modes - eigenmodes_updated = probe.dictionary_matrix @ sparse_code_opr + if update_eigenmode_weights: + + delta_Ws = -2 * torch.real(U.conj() @ Tsconj_chi).to(torch.complex64) - eigenmodes_updated = torch.reshape(eigenmodes_updated.T, - (eigenmodes_updated.shape[-1], sz[-2], sz[-1])) - - #================================================ - # rescale so that all OPR modes have correct norm - - eigenmodes_current_scaling = torch.linalg.norm(eigenmodes_updated, axis=(-1,-2), keepdims=True) - eigenmodes_updated_scaling = torch.sqrt(torch.prod(torch.tensor(eigenmodes_updated.shape[-2:]))) - - scaling_ratio = eigenmodes_updated_scaling / eigenmodes_current_scaling - - eigenmodes_updated = eigenmodes_updated * scaling_ratio + Ts_U_deltaWs = Ts.T * (U.T @ delta_Ws) + numer = torch.sum(torch.real(chi_vec * Ts_U_deltaWs.H)) + denom = torch.sum(torch.real( Ts_U_deltaWs.conj() * Ts_U_deltaWs )) + optimal_step_deltaWs = self.options.update_relaxation * (numer / denom) - #===== + Ws = (Ws + optimal_step_deltaWs * delta_Ws.T) + + if (probe.representation == "sparse_code" + and probe.options.experimental.sdl_probe_options.enabled_opr): + + # Optimal sparse code OPR mode updates + + delta_U = -1 * Tsconj_chi @ Ws - w = 1 - 0e-6 - probe_data[1:, 0, :, :] = w * eigenmodes_updated + ( 1 - w ) * probe_data[:, 0, :, :] - - probe.set_data(probe_data) - update_eigenmode = False + delta_sparse_code_probe_opr = probe.dictionary_matrix.H @ delta_U + + Gs = probe.dictionary_matrix @ delta_sparse_code_probe_opr @ Ws.T + TsHGsH = Ts.H * Gs.conj() + numer = torch.sum( torch.real(TsHGsH * chi_vec.T)) + denom = torch.sum( torch.real(TsHGsH * TsHGsH.conj())) + optimal_step_sparse_code_probe_opr = probe.options.eigenmode_update_relaxation * (numer / denom) + + sparse_code_probe_opr = probe.get_sparse_code_probe_opr_weights() + + optimal_sparse_code_probe_opr = (sparse_code_probe_opr + + optimal_step_sparse_code_probe_opr * delta_sparse_code_probe_opr.T) + + # Enforce sparsity constraint on sparse code + abs_sparse_code = torch.abs(optimal_sparse_code_probe_opr) + abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) + sel = abs_sparse_code_sorted[0][:, probe.sparse_code_probe_nnz] + sparse_code_mask = (abs_sparse_code >= sel[:,None]) + + # Hard or Soft thresholding + if probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'hard': + optimal_sparse_code_probe_opr = optimal_sparse_code_probe_opr * sparse_code_mask + elif probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'soft': + optimal_sparse_code_probe_opr = ( abs_sparse_code - sel[:,None] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_sparse_code_probe_opr)) + + probe.set_sparse_code_probe_opr(optimal_sparse_code_probe_opr) + + # Back to dense OPR representation + U = (probe.dictionary_matrix @ optimal_sparse_code_probe_opr.T).T + + # the OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc)) + U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U)**2, -1))[:,None] + + U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1])) + + probe_data[1:, 0, :, :] = U + weights_data[indices, 1:] = Ws.real + + # DELETE THIS FOR FINAL MERGING + # DELETE THIS FOR FINAL MERGING + + # Test the rank of the new scan position dependent probe: + + # probe_data_TEST = torch.reshape(probe_data[:,0,...], (probe_data.shape[0], probe_data.shape[-1] * probe_data.shape[-2])) + # Z1 = torch.sum(probe_data[:, 0, :, :][None,...] * weights_data[indices][...,None,None], 1) + # Z1 = torch.reshape(Z1, (Z1.shape[0], Z1.shape[1] * Z1.shape[2])) + # Z2 = probe_data_TEST.T @ weights_data[indices, :].T.to(torch.complex64) + # print( torch.linalg.matrix_rank(Z1) ) + # print( torch.linalg.matrix_rank(Z2) ) + + # DELETE THIS FOR FINAL MERGING + # DELETE THIS FOR FINAL MERGING + + else: + + # Optimal dense OPR mode updates: - # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. - relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation - relax_v = self.options.update_relaxation - # Shape of delta_p_i: (batch_size, n_probe_modes, h, w) - # Use only the first incoherent mode - delta_p_i = delta_p_i[:, 0, :, :] - delta_p_hat = delta_p_hat[0, :, :] - residue_update = delta_p_i - delta_p_hat - - # Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode. - for i_opr_mode in range(1, probe.n_opr_modes): - # Just take the first incoherent mode. - eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode) - weights_i = self.get_weights(indices)[:, i_opr_mode] - eigenmode_i, weights_i = self._update_first_eigenmode_and_weight( - residue_update, - eigenmode_i, - weights_i, - relax_u, - relax_v, - obj_patches, - chi, - update_eigenmode = update_eigenmode, - update_weights=self.eigenmode_weight_optimization_enabled(current_epoch), - ) + delta_U = -1 * Tsconj_chi @ Ws - # Project residue on this eigenmode, then subtract it. - if i_opr_mode < probe.n_opr_modes - 1: - residue_update = residue_update - pmath.project( - residue_update, eigenmode_i, dim=(-2, -1) + Ts_deltaU_Ws = Ts.T * (delta_U @ Ws.T) + numer = torch.sum(torch.real(chi_vec * Ts_deltaU_Ws.H)) + denom = torch.sum(torch.real( Ts_deltaU_Ws.conj() * Ts_deltaU_Ws )) + optimal_step_deltaU = probe.options.eigenmode_update_relaxation * (numer / denom) + + U = U + optimal_step_deltaU * delta_U.T + + # the OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc)) + U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U)**2, -1))[:,None] + + U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1])) + + probe_data[1:, 0, :, :] = U + weights_data[indices, 1:] = Ws.real + + else: + + # Ptychoshelves method for OPR updates + + # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. + relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation + relax_v = self.options.update_relaxation + # Shape of delta_p_i: (batch_size, n_probe_modes, h, w) + # Use only the first incoherent mode + delta_p_i = delta_p_i[:, 0, :, :] + delta_p_hat = delta_p_hat[0, :, :] + residue_update = delta_p_i - delta_p_hat + + # Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode. + for i_opr_mode in range(1, probe.n_opr_modes): + # Just take the first incoherent mode. + eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode) + weights_i = self.get_weights(indices)[:, i_opr_mode] + eigenmode_i, weights_i = self._update_first_eigenmode_and_weight( + residue_update, + eigenmode_i, + weights_i, + relax_u, + relax_v, + obj_patches, + chi, + update_eigenmode=update_eigenmode, + update_weights=self.eigenmode_weight_optimization_enabled(current_epoch), ) - probe_data[i_opr_mode, 0, :, :] = eigenmode_i - weights_data[indices, i_opr_mode] = weights_i + # Project residue on this eigenmode, then subtract it. + if i_opr_mode < probe.n_opr_modes - 1: + residue_update = residue_update - pmath.project( + residue_update, eigenmode_i, dim=(-2, -1) + ) - # print( torch.linalg.norm(probe_data,dim=(-1,-2)), torch.linalg.norm( weights_data,dim=(0)) ) + probe_data[i_opr_mode, 0, :, :] = eigenmode_i + weights_data[indices, i_opr_mode] = weights_i - if update_eigenmode and not (probe.representation == "sparse_code"): - probe.set_data(probe_data) - if self.eigenmode_weight_optimization_enabled(current_epoch): + if update_eigenmode: + probe.set_data(probe_data) + + if update_eigenmode_weights: self.set_data(weights_data) @timer() diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 6e2b0de..73345da 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -221,7 +221,12 @@ def constrain_incoherent_modes_orthogonality(self): return probe = self.data - + + if self.options.orthogonalize_incoherent_modes.sort_by_occupancy: + shared_occupancy = torch.sum(torch.abs(probe[0,...])**2,(-2,-1)) / torch.sum(torch.abs(probe[0,...])**2) + shared_occupancy = torch.sort(shared_occupancy, dim=0, descending=True) + probe[0,...] = probe[ 0, shared_occupancy[1],...] + norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1)) if self.orthogonalize_incoherent_modes_method == "gs": @@ -463,42 +468,22 @@ def __init__(self, name = "probe", options = None, *args, **kwargs): super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs) dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary() + self.register_buffer("dictionary_matrix", dictionary_matrix) self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv) - - sparse_code_probe_shared_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32 ) - self.register_buffer("sparse_code_probe_nnz", sparse_code_probe_shared_nnz ) + sparse_code_probe_shared_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32 ) sparse_code_probe_opr_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_nnz, dtype=torch.uint32 ) + + self.register_buffer("sparse_code_probe_nnz", sparse_code_probe_shared_nnz ) self.register_buffer("sparse_code_opr_nnz", sparse_code_probe_opr_nnz ) - sparse_code_probe_shared = self.get_sparse_code_weights() - self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)) + sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() + sparse_code_probe_opr = self.get_sparse_code_probe_opr_weights() - sparse_code_probe_opr = self.get_sparse_code_weights() + self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)) self.register_parameter("sparse_code_probe_opr", torch.nn.Parameter(sparse_code_probe_opr)) - - use_avg_spos_sparse_code = self.options.experimental.sdl_probe_options.use_avg_spos_sparse_code - self.register_buffer("use_avg_spos_sparse_code", torch.tensor(use_avg_spos_sparse_code, dtype=torch.bool)) - - sparse_code_probe_shared_start = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_start ) - self.register_buffer("sparse_code_probe_shared_start", sparse_code_probe_shared_start ) - - sparse_code_probe_shared_stop = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_stop ) - self.register_buffer("sparse_code_probe_shared_stop", sparse_code_probe_shared_stop ) - - sparse_code_probe_shared_stride = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_stride ) - self.register_buffer("sparse_code_probe_shared_stride", sparse_code_probe_shared_stride ) - sparse_code_probe_opr_start = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_start ) - self.register_buffer("sparse_code_probe_opr_start", sparse_code_probe_opr_start ) - - sparse_code_probe_opr_stop = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_stop ) - self.register_buffer("sparse_code_probe_opr_stop", sparse_code_probe_opr_stop ) - - sparse_code_probe_opr_stride = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_stride ) - self.register_buffer("sparse_code_probe_opr_stride", sparse_code_probe_opr_stride ) - self.build_optimizer() def get_dictionary(self): @@ -508,31 +493,31 @@ def get_dictionary(self): return dictionary_matrix, dictionary_matrix_pinv - def get_sparse_code_weights(self): + def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions ): - sz = self.data.shape - probe_vec = torch.reshape(self.data, (sz[0], sz[1], sz[2] * sz[3])) - sparse_code = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + sz = probe_vs_scanpositions.shape + probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2]*sz[3])) + sparse_code_vs_scanpositions = torch.einsum('ij,klj->ikl', self.dictionary_matrix_pinv, probe_vec) - return sparse_code + return sparse_code_vs_scanpositions def get_sparse_code_probe_shared_weights(self): probe_shared = self.data[0,...] sz = probe_shared.shape - probe_vec = torch.reshape(probe_shared, (sz[0], sz[1] * sz[2])) - sparse_code = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + probe_vec = torch.reshape(probe_shared, (sz[0], sz[1]*sz[2])) + sparse_code_probe_shared = self.dictionary_matrix_pinv @ probe_vec.T - return sparse_code[None,...] + return sparse_code_probe_shared.T def get_sparse_code_probe_opr_weights(self): probe_opr = self.data[1:,0,...] sz = probe_opr.shape - probe_vec = torch.reshape(probe_opr, (sz[0], sz[1] * sz[2])) - sparse_code = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec) + probe_vec = torch.reshape(probe_opr, (sz[0], sz[1]*sz[2])) + sparse_code_probe_opr = self.dictionary_matrix_pinv @ probe_vec.T - return sparse_code[:,None,...] + return sparse_code_probe_opr.T def generate(self): """Generate the probe using the sparse code, and set the @@ -544,17 +529,49 @@ def generate(self): A (n_opr_modes, n_modes, h, w) tensor giving the generated probe. """ - if (self.options.experimental.sdl_probe_options.enabled_shared): + if (self.options.experimental.sdl_probe_options.enabled_shared + and self.options.experimental.sdl_probe_options.enabled_opr): + + sz = self.data.shape + probe = torch.zeros( *[sz], dtype = torch.complex64 ) + + probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T + probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T - # GENERATE OPR MODES FOR probe[1:,...] HERE USING self.sparse_code_probe_opr + probe[0,...] = torch.reshape( probe_shared.T, *[sz[1:]] ) + probe[1:,0,...] = torch.reshape( probe_opr.T, [sz[0] - 1, sz[-2], sz[-1]] ) + + self.set_data(probe) + + elif (self.options.experimental.sdl_probe_options.enabled_shared + and not self.options.experimental.sdl_probe_options.enabled_opr): - probe = torch.einsum('ij,klj->kli', self.dictionary_matrix, self.sparse_code_probe_shared) - probe = torch.reshape( probe, *[self.data.shape] ) + sz = self.data.shape + probe = torch.zeros( *[sz], dtype = torch.complex64 ) + probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T + + probe[0,...] = torch.reshape( probe_shared.T, *[sz[1:]] ) + probe[1:,0,...] = self.data[1:,0,...] + + self.set_data(probe) + + elif (self.options.experimental.sdl_probe_options.enabled_opr + and not self.options.experimental.sdl_probe_options.enabled_shared): + + sz = self.data.shape + probe = torch.zeros( *[sz], dtype = torch.complex64 ) + + probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T + + probe[0,...] = self.data[0,...] + probe[1:,0,...] = torch.reshape( probe_opr.T, [sz[0] - 1, sz[-2], sz[-1]] ) + self.set_data(probe) else: - probe = self.probe.data + + probe = self.data return probe diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index ac05c12..89919c0 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -225,80 +225,69 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) ) ) - if (self.parameter_group.probe.representation == "sparse_code" - and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared - and self.parameter_group.probe.options.experimental.sdl_probe_options.sparse_code_probe_shared_start >= self.current_epoch - and self.parameter_group.probe.options.experimental.sdl_probe_options.sparse_code_probe_shared_stop <= self.current_epoch - and (self.current_epoch % self.parameter_group.probe.options.experimental.sdl_probe_options.sparse_code_probe_shared_stride) == 0 - ): - - #if (self.parameter_group.probe.representation == "sparse_code") and (self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): - - rc = chi.shape[-1] * chi.shape[-2] - n_scpm = chi.shape[-3] - n_spos = chi.shape[-4] - - # sparse code update directions vs scan position and shared probe modes - obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) - delta_sparse_code = chi * obj_patches_slice_i_conj[:, None, ... ] - delta_sparse_code = self.adjoint_shift_probe_update_direction(indices, delta_sparse_code, first_mode_only=True) - delta_sparse_code = torch.reshape( delta_sparse_code, - ( n_spos, n_scpm, rc )) - delta_sparse_code = torch.einsum('ijk,kl->lij', - delta_sparse_code, - self.parameter_group.probe.dictionary_matrix.conj()) - - # compute optimal step length for sparse code update - dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', - self.parameter_group.probe.dictionary_matrix, - delta_sparse_code) - - obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc )) - denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] - denom = torch.einsum('ij,jik->ik', - torch.conj( obj_patches_vec ), - denom) - - chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True) - numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, - ( n_spos, n_scpm, rc )).permute(2,0,1) - numer = torch.einsum('ij,jik->ik', - torch.conj( obj_patches_vec ), - numer) - - # real is used to throw away small imag part due to numerical precision errors - optimal_step_sparse_code = ( numer / denom ).real - - optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code - - # TODO: At some point we'll want to use scan position dependent sparse codes instead of this averaged sparse code - optimal_delta_sparse_code_mean_spos = ( optimal_delta_sparse_code.mean(1).T )[None, ...] - - # sparse code update - sparse_code2 = self.parameter_group.probe.get_sparse_code_probe_shared_weights() - sparse_code = self.parameter_group.probe.get_sparse_code_weights() - sparse_code = sparse_code + optimal_delta_sparse_code_mean_spos - - # Enforce sparsity constraint on sparse code - abs_sparse_code = torch.abs(sparse_code) - sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) - - sel = sparse_code_sorted[0][...,self.parameter_group.probe.sparse_code_probe_nnz] - - # hard thresholding: - sparse_code = sparse_code * (abs_sparse_code >= sel[...,None]) - #(TODO: soft thresholding option) - - # Update the new sparse code in the probe class - self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code) - - # Create the probe update vs scan position using the sparse code - delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, - optimal_delta_sparse_code) - delta_p_i = delta_p_i.permute(1, 2, 0) - delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) - delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + # TODO: move this to SynthesisDictLearnProbe class methods, so it can be used in rPIE as well + if (self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): + + rc = chi.shape[-1] * chi.shape[-2] + n_scpm = chi.shape[-3] + n_spos = chi.shape[-4] + + # sparse code update directions vs scan position and shared probe modes + obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) + delta_sparse_code = self.adjoint_shift_probe_update_direction(indices, chi * obj_patches_slice_i_conj[:, None, ... ], first_mode_only=True) + delta_sparse_code = torch.reshape( delta_sparse_code, + ( n_spos, n_scpm, rc )) + delta_sparse_code = torch.einsum('ijk,kl->lij', + delta_sparse_code, + self.parameter_group.probe.dictionary_matrix.conj()) + + # compute optimal step length for sparse code update + dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', + self.parameter_group.probe.dictionary_matrix, + delta_sparse_code) + + obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc )) + denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] + denom = torch.einsum('ij,jik->ik', + torch.conj( obj_patches_vec ), + denom) + + chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True) + numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, + ( n_spos, n_scpm, rc )).permute(2,0,1) + numer = torch.einsum('ij,jik->ik', + torch.conj( obj_patches_vec ), + numer) + + # real is used to throw away small imag part due to numerical precision errors + optimal_step_sparse_code = ( numer / denom ).real + optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code + + # Enforce sparsity constraint on sparse code + abs_sparse_code = torch.abs(optimal_delta_sparse_code) + abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + + sel = abs_sparse_code_sorted[0][self.parameter_group.probe.sparse_code_probe_nnz, ...] + sparse_code_mask = (abs_sparse_code >= sel[None,...]) + + # Hard or Soft thresholding + if self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': + optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask + elif self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': + optimal_delta_sparse_code = ( abs_sparse_code - sel[None,...] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) + + # update the shared probe sparse codes using the average over scan positions + sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() + sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T + self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) + + delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, + optimal_delta_sparse_code).permute(1, 2, 0) + delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) + delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + else: # Calculate probe update direction. delta_p_i_unshifted = self._calculate_probe_update_direction( @@ -316,11 +305,11 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions) if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch): self.parameter_group.opr_mode_weights.update_variable_probe( self.parameter_group.probe, + self.adjoint_shift_probe_update_direction, indices, chi, delta_p_i, delta_p_hat, - probe_current_slice, obj_patches, self.current_epoch, probe_mode_index=0, From 42b5187db4c0533c7295b6d6101873b4a9d48dd4 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 27 Aug 2025 16:33:22 -0500 Subject: [PATCH 9/9] updated a doc string --- src/ptychi/api/options/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index ccb652c..660dde7 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -671,14 +671,14 @@ class SynthesisDictLearnProbeOptions(Options): sparse_code_probe_shared_nnz: float = None """Number of non-zeros we will keep when enforcing sparsity constraint on - the sparse code weights vector sparse_code_probe_shared.""" + the SHARED sparse code weights vector sparse_code_probe_shared.""" sparse_code_probe_opr: Union[ndarray, Tensor] = None """Sparse code weights vector for the OPRs.""" sparse_code_probe_opr_nnz: float = None """Number of non-zeros we will keep when enforcing sparsity constraint on - the sparse code weights vector sparse_code_probe_opr.""" + the OPR sparse code weights vector sparse_code_probe_opr.""" @dataclasses.dataclass class PositionCorrectionOptions(Options):