Skip to content

Commit d77e6cb

Browse files
committed
TEST: finished adding xrf_data_1 tester
1 parent c4dc1ab commit d77e6cb

File tree

5 files changed

+157
-71
lines changed

5 files changed

+157
-71
lines changed

src/pyxalign/alignment/projection_matching.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import partial
22
from typing import Callable, Optional
3-
import matplotlib
43
import numpy as np
54
import cupy as cp
65
import copy
@@ -20,15 +19,14 @@
2019
import pyxalign.image_processing as ip
2120
import pyxalign.api.maps as maps
2221
from pyxalign.api.enums import DeviceType, MemoryConfig
23-
from pyxalign.api.options.alignment import ProjectionMatchingOptions, ProjectionMatchingPlotOptions
22+
from pyxalign.api.options.alignment import ProjectionMatchingOptions
2423
import pyxalign.gpu_utils as gutils
2524
from pyxalign.api.types import ArrayType, r_type
2625
from pyxalign.gpu_wrapper import device_handling_wrapper
2726
from IPython.display import clear_output
2827
import matplotlib.pyplot as plt
2928
from tqdm import tqdm
3029
import astra
31-
import time
3230

3331

3432
class ProjectionMatchingAligner(Aligner):

src/pyxalign/data_structures/task.py

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from PyQt5.QtWidgets import QApplication
55

66
from pyxalign import gpu_utils
7+
from pyxalign.api.options.alignment import ProjectionMatchingOptions
78
from pyxalign.data_structures.projections import (
89
ComplexProjections,
910
PhaseProjections,
@@ -37,7 +38,7 @@ def __init__(
3738
self.complex_projections = complex_projections
3839
self.phase_projections = phase_projections
3940
self.pma_object: ProjectionMatchingAligner = None
40-
self.pma_GUI_list: list[ProjectionMatchingViewer] = []
41+
self.pma_gui_list: list[ProjectionMatchingViewer] = []
4142

4243
def get_cross_correlation_shift(
4344
self,
@@ -68,46 +69,40 @@ def get_cross_correlation_shift(
6869
projections.plot_staged_shift("Cross-correlation Shift")
6970
print("Cross-correlation shift stored in shift_manager")
7071

71-
def get_projection_matching_shift(self, initial_shift: Optional[np.ndarray] = None):
72+
def get_projection_matching_shift(self, initial_shift: Optional[np.ndarray] = None) -> np.ndarray:
73+
# clear existing astra objects
7274
if self.pma_object is not None:
7375
if hasattr(self.pma_object, "aligned_projections"):
74-
# Clear old astra objects
7576
self.pma_object.aligned_projections.volume.clear_astra_objects()
76-
77+
78+
# reset timers
7779
clear_timer_globals()
78-
# Initialize the projection-matching alignment object
79-
self.pma_object = ProjectionMatchingAligner(
80-
self.phase_projections, self.options.projection_matching
81-
)
80+
81+
# close old gui windows
8282
if self.options.projection_matching.interactive_viewer.close_old_windows:
83-
self.clear_pma_GUI_list()
84-
try:
85-
if self.pma_object.options.interactive_viewer.update.enabled:
86-
# Run PMA algorithm
87-
shift = self.pma_object.run_with_GUI(initial_shift=initial_shift)
88-
# Store the QWidget in a list so the window remains open even if
89-
# another PMA loop is started
90-
self.pma_GUI_list += [self.pma_object.gui] # uncomment later
91-
# Close window
92-
# self.pma_object.gui.close() # I think adding this helped, or removing the list helped.
93-
else:
94-
# Run PMA algorithm
95-
shift = self.pma_object.run(initial_shift=initial_shift)
96-
except (Exception, KeyboardInterrupt):
97-
shift = self.pma_object.total_shift * self.pma_object.scale
98-
finally:
99-
# Store the result in the ShiftManager object
100-
self.phase_projections.shift_manager.stage_shift(
101-
shift=shift,
102-
function_type=enums.ShiftType.FFT,
103-
alignment_options=self.options.projection_matching,
104-
)
105-
print("Projection-matching shift stored in shift_manager")
83+
self.clear_pma_gui_list()
84+
else:
85+
self.pma_gui_list += [self.pma_object.gui]
86+
87+
# run the pma algorithm
88+
self.pma_object, shift = run_projection_matching(
89+
self.phase_projections, initial_shift, self.options.projection_matching
90+
)
10691

107-
def clear_pma_GUI_list(self):
108-
for gui in self.pma_GUI_list:
92+
# Store the result in the ShiftManager object
93+
self.phase_projections.shift_manager.stage_shift(
94+
shift=shift,
95+
function_type=enums.ShiftType.FFT,
96+
alignment_options=self.options.projection_matching,
97+
)
98+
print("Projection-matching shift stored in shift_manager")
99+
100+
return shift
101+
102+
def clear_pma_gui_list(self):
103+
for gui in self.pma_gui_list:
109104
gui.close()
110-
self.pma_GUI_list = []
105+
self.pma_gui_list = []
111106

112107
def get_complex_projection_masks(self, enable_plotting: bool = False):
113108
clear_timer_globals()
@@ -150,3 +145,23 @@ def launch_viewer(self):
150145
app = QApplication.instance() or QApplication([])
151146
self.gui = TaskViewer(self)
152147
self.gui.show()
148+
149+
150+
def run_projection_matching(
151+
phase_projections: PhaseProjections,
152+
initial_shift: np.ndarray,
153+
projection_matching_options: ProjectionMatchingOptions,
154+
) -> tuple[ProjectionMatchingAligner, np.ndarray]:
155+
# Initialize the projection-matching alignment object
156+
pma_object = ProjectionMatchingAligner(phase_projections, projection_matching_options)
157+
try:
158+
if pma_object.options.interactive_viewer.update.enabled:
159+
# Run PMA algorithm
160+
shift = pma_object.run_with_GUI(initial_shift=initial_shift)
161+
else:
162+
# Run PMA algorithm
163+
shift = pma_object.run(initial_shift=initial_shift)
164+
except (Exception, KeyboardInterrupt):
165+
shift = pma_object.total_shift * pma_object.scale
166+
finally:
167+
return pma_object, shift

src/pyxalign/data_structures/xrf_task.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,36 @@
11
from typing import Optional
22
import numpy as np
3+
from pyxalign import LaminographyAlignmentTask
34
from pyxalign.alignment.cross_correlation import CrossCorrelationAligner
45
from pyxalign.api import enums
56
from pyxalign.api.options.device import DeviceOptions
67
from pyxalign.api.options.projections import ProjectionOptions
78
from pyxalign.api.options.task import AlignmentTaskOptions
9+
from pyxalign.data_structures.task import run_projection_matching
810
from pyxalign.data_structures.xrf_projections import XRFProjections
911
from pyxalign.timing.timer_utils import clear_timer_globals
1012
from pyxalign.api.types import r_type
1113

1214

1315
class XRFTask:
14-
_primary_channel: str
15-
projections_dict: dict[str, XRFProjections] = {}
16-
1716
def __init__(
1817
self,
1918
xrf_array_dict: dict[str, np.ndarray],
2019
angles: np.ndarray,
2120
scan_numbers: np.ndarray,
2221
projection_options: ProjectionOptions,
23-
task_options: AlignmentTaskOptions,
22+
alignment_options: AlignmentTaskOptions,
2423
primary_channel: str,
2524
):
25+
self.pma_object = None
26+
self.pma_gui_list = []
27+
self._primary_channel: str
28+
self.projections_dict: dict[str, XRFProjections] = {}
2629
# self.angles = angles
2730
# self.scan_numbers = scan_numbers
2831
self.channels = xrf_array_dict.keys()
2932
self.projection_options = projection_options
30-
self.task_options = task_options
33+
self.alignment_options = alignment_options
3134
self._primary_channel = primary_channel
3235
self.create_xrf_projections_object(xrf_array_dict, angles, scan_numbers)
3336
self._center_of_rotation = self.projections_dict[self._primary_channel].center_of_rotation
@@ -55,6 +58,10 @@ def create_xrf_projections_object(
5558
# transform_tracker=
5659
)
5760

61+
# @property
62+
# def phase_projections(self):
63+
# return self.projections_dict[self._primary_channel]
64+
5865
@property
5966
def angles(self):
6067
return self.projections_dict[self._primary_channel].angles
@@ -87,26 +94,66 @@ def center_of_rotation(self):
8794
@center_of_rotation.setter
8895
def center_of_rotation(self, center_of_rotation: np.ndarray):
8996
self._center_of_rotation = center_of_rotation
90-
for channel, projections in self.projections_dict.items():
97+
for _, projections in self.projections_dict.items():
9198
projections.center_of_rotation = self._center_of_rotation * 1
9299

93100
def apply_staged_shift_to_all_channels(self, device_options: Optional[DeviceOptions] = None):
94-
for channel, projections in self.projections_dict.items():
101+
for _, projections in self.projections_dict.items():
95102
projections.apply_staged_shift(device_options)
96103

97104
def drop_projections_from_all_channels(self, remove_idx: list[int]):
98-
for channel, projections in self.projections_dict.items():
105+
for _, projections in self.projections_dict.items():
99106
projections.drop_projections(remove_idx=remove_idx)
100107

101108
def pin_all_arrays(self):
102-
for channel, projections in self.projections_dict.items():
109+
for _, projections in self.projections_dict.items():
103110
projections.pin_arrays()
104111

112+
def get_projection_matching_shift(
113+
self, initial_shift: Optional[np.ndarray] = None
114+
) -> np.ndarray:
115+
# clear existing astra objects
116+
if self.pma_object is not None:
117+
if hasattr(self.pma_object, "aligned_projections"):
118+
self.pma_object.aligned_projections.volume.clear_astra_objects()
119+
120+
# reset timers
121+
clear_timer_globals()
122+
123+
# close old gui windows
124+
if self.alignment_options.projection_matching.interactive_viewer.close_old_windows:
125+
self.clear_pma_gui_list()
126+
else:
127+
self.pma_gui_list += [self.pma_object.gui]
128+
129+
# run the pma algorithm
130+
self.pma_object, shift = run_projection_matching(
131+
self.projections_dict[self.primary_channel],
132+
initial_shift,
133+
self.alignment_options.projection_matching,
134+
)
135+
136+
# Save the resulting alignment shift
137+
for _, projections in self.projections_dict.items():
138+
projections.shift_manager.stage_shift(
139+
shift=shift,
140+
function_type=enums.ShiftType.FFT,
141+
alignment_options=self.alignment_options.projection_matching,
142+
)
143+
print("Projection-matching shifts stored in shift_manager")
144+
145+
return shift
146+
147+
def clear_pma_gui_list(self):
148+
for gui in self.pma_gui_list:
149+
gui.close()
150+
self.pma_gui_list = []
151+
105152
def get_cross_correlation_shift(self, illum_sum: np.ndarray = None):
106153
clear_timer_globals()
107154
self.cross_correlation_aligner = CrossCorrelationAligner(
108155
projections=self.projections_dict[self._primary_channel],
109-
options=self.task_options.cross_correlation,
156+
options=self.alignment_options.cross_correlation,
110157
)
111158
# Placeholder for actual illum_sum
112159
if illum_sum is None:
@@ -121,7 +168,7 @@ def get_cross_correlation_shift(self, illum_sum: np.ndarray = None):
121168
projections.shift_manager.stage_shift(
122169
shift=shift,
123170
function_type=enums.ShiftType.CIRC,
124-
alignment_options=self.task_options.cross_correlation,
171+
alignment_options=self.alignment_options.cross_correlation,
125172
)
126173
projections.plot_staged_shift("Cross-correlation Shift")
127174
print("Cross-correlation shift stored in shift_manager")

tests/full_tests/cSAXS_e18044_LamNI_201907.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ def update_pma_options(pma_options: opts.ProjectionMatchingOptions, scale: int):
280280
)
281281

282282
# Shift the projections by the projection-matching alignment shift
283-
print(multi_gpu_device_options)
284283
task.phase_projections.apply_staged_shift(multi_gpu_device_options)
285284

286285
# Save the fully aligned task

0 commit comments

Comments
 (0)