11from typing import Optional
22import numpy as np
3+ from pyxalign import LaminographyAlignmentTask
34from pyxalign .alignment .cross_correlation import CrossCorrelationAligner
45from pyxalign .api import enums
56from pyxalign .api .options .device import DeviceOptions
67from pyxalign .api .options .projections import ProjectionOptions
78from pyxalign .api .options .task import AlignmentTaskOptions
9+ from pyxalign .data_structures .task import run_projection_matching
810from pyxalign .data_structures .xrf_projections import XRFProjections
911from pyxalign .timing .timer_utils import clear_timer_globals
1012from pyxalign .api .types import r_type
1113
1214
1315class XRFTask :
14- _primary_channel : str
15- projections_dict : dict [str , XRFProjections ] = {}
16-
1716 def __init__ (
1817 self ,
1918 xrf_array_dict : dict [str , np .ndarray ],
2019 angles : np .ndarray ,
2120 scan_numbers : np .ndarray ,
2221 projection_options : ProjectionOptions ,
23- task_options : AlignmentTaskOptions ,
22+ alignment_options : AlignmentTaskOptions ,
2423 primary_channel : str ,
2524 ):
25+ self .pma_object = None
26+ self .pma_gui_list = []
27+ self ._primary_channel : str
28+ self .projections_dict : dict [str , XRFProjections ] = {}
2629 # self.angles = angles
2730 # self.scan_numbers = scan_numbers
2831 self .channels = xrf_array_dict .keys ()
2932 self .projection_options = projection_options
30- self .task_options = task_options
33+ self .alignment_options = alignment_options
3134 self ._primary_channel = primary_channel
3235 self .create_xrf_projections_object (xrf_array_dict , angles , scan_numbers )
3336 self ._center_of_rotation = self .projections_dict [self ._primary_channel ].center_of_rotation
@@ -55,6 +58,10 @@ def create_xrf_projections_object(
5558 # transform_tracker=
5659 )
5760
61+ # @property
62+ # def phase_projections(self):
63+ # return self.projections_dict[self._primary_channel]
64+
5865 @property
5966 def angles (self ):
6067 return self .projections_dict [self ._primary_channel ].angles
@@ -87,26 +94,66 @@ def center_of_rotation(self):
8794 @center_of_rotation .setter
8895 def center_of_rotation (self , center_of_rotation : np .ndarray ):
8996 self ._center_of_rotation = center_of_rotation
90- for channel , projections in self .projections_dict .items ():
97+ for _ , projections in self .projections_dict .items ():
9198 projections .center_of_rotation = self ._center_of_rotation * 1
9299
93100 def apply_staged_shift_to_all_channels (self , device_options : Optional [DeviceOptions ] = None ):
94- for channel , projections in self .projections_dict .items ():
101+ for _ , projections in self .projections_dict .items ():
95102 projections .apply_staged_shift (device_options )
96103
97104 def drop_projections_from_all_channels (self , remove_idx : list [int ]):
98- for channel , projections in self .projections_dict .items ():
105+ for _ , projections in self .projections_dict .items ():
99106 projections .drop_projections (remove_idx = remove_idx )
100107
101108 def pin_all_arrays (self ):
102- for channel , projections in self .projections_dict .items ():
109+ for _ , projections in self .projections_dict .items ():
103110 projections .pin_arrays ()
104111
112+ def get_projection_matching_shift (
113+ self , initial_shift : Optional [np .ndarray ] = None
114+ ) -> np .ndarray :
115+ # clear existing astra objects
116+ if self .pma_object is not None :
117+ if hasattr (self .pma_object , "aligned_projections" ):
118+ self .pma_object .aligned_projections .volume .clear_astra_objects ()
119+
120+ # reset timers
121+ clear_timer_globals ()
122+
123+ # close old gui windows
124+ if self .alignment_options .projection_matching .interactive_viewer .close_old_windows :
125+ self .clear_pma_gui_list ()
126+ else :
127+ self .pma_gui_list += [self .pma_object .gui ]
128+
129+ # run the pma algorithm
130+ self .pma_object , shift = run_projection_matching (
131+ self .projections_dict [self .primary_channel ],
132+ initial_shift ,
133+ self .alignment_options .projection_matching ,
134+ )
135+
136+ # Save the resulting alignment shift
137+ for _ , projections in self .projections_dict .items ():
138+ projections .shift_manager .stage_shift (
139+ shift = shift ,
140+ function_type = enums .ShiftType .FFT ,
141+ alignment_options = self .alignment_options .projection_matching ,
142+ )
143+ print ("Projection-matching shifts stored in shift_manager" )
144+
145+ return shift
146+
147+ def clear_pma_gui_list (self ):
148+ for gui in self .pma_gui_list :
149+ gui .close ()
150+ self .pma_gui_list = []
151+
105152 def get_cross_correlation_shift (self , illum_sum : np .ndarray = None ):
106153 clear_timer_globals ()
107154 self .cross_correlation_aligner = CrossCorrelationAligner (
108155 projections = self .projections_dict [self ._primary_channel ],
109- options = self .task_options .cross_correlation ,
156+ options = self .alignment_options .cross_correlation ,
110157 )
111158 # Placeholder for actual illum_sum
112159 if illum_sum is None :
@@ -121,7 +168,7 @@ def get_cross_correlation_shift(self, illum_sum: np.ndarray = None):
121168 projections .shift_manager .stage_shift (
122169 shift = shift ,
123170 function_type = enums .ShiftType .CIRC ,
124- alignment_options = self .task_options .cross_correlation ,
171+ alignment_options = self .alignment_options .cross_correlation ,
125172 )
126173 projections .plot_staged_shift ("Cross-correlation Shift" )
127174 print ("Cross-correlation shift stored in shift_manager" )
0 commit comments