Skip to content

Commit e515fd2

Browse files
committed
implement transformation graph traversal
1 parent 8408cf5 commit e515fd2

1 file changed

Lines changed: 61 additions & 4 deletions

File tree

src/ngff_transformations/transform.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,71 @@
11
import numpy as np
22
from xarray import DataArray
33
import networkx as nx
4-
import ome_zarr_models as ozm
4+
from ome_zarr_models._v06.coordinate_transforms import CoordinateSystemIdentifier, Sequence
55

6-
7-
def validata_point_shape(point: np.ndarray, transformation_sequence: ozm._v06.coordinate_transforms.Sequence):
6+
def validata_point_shape(point: np.ndarray, transformation_sequence: Sequence):
87
for transformation in transformation_sequence.transformations:
98
assert len(point) == transformation.ndim, "Point ndim doesn't match transformation ndim"
109

11-
def transform_with_sequence(data: np.ndarray, transformation_sequence: ozm._v06.coordinate_transforms.Sequence,
10+
11+
def get_node(path: str | None = None, name: str | None = None) -> str | CoordinateSystemIdentifier:
12+
if path is None and name is None:
13+
raise ValueError("Both path and name of the coordinate system cannot be None")
14+
if path is None:
15+
return name
16+
if name is None:
17+
return path
18+
return CoordinateSystemIdentifier(path=path, name=name)
19+
20+
def find_walks_in_graph(graph, src_path, src_name, tgt_path, tgt_name):
21+
src_node = get_node(src_path, src_name)
22+
tgt_node = get_node(tgt_path, tgt_name)
23+
24+
graph_walk = list(nx.all_shortest_paths(graph, src_node, tgt_node))[0]
25+
26+
transformation_sequence = []
27+
for i in range(len(graph_walk) - 1):
28+
transformation_sequence.append(graph.get_edge_data(graph_walk[i], graph_walk[i + 1])['transformation'])
29+
30+
transformation_sequence = Sequence(
31+
input=graph_walk[0],
32+
output=graph_walk[-1],
33+
transformations=transformation_sequence
34+
)
35+
return transformation_sequence, (graph_walk[0], graph_walk[-1])
36+
37+
def transform_with_sequence3D(data: np.ndarray, axes: list[str], transformation_sequence: Sequence,
38+
output_axes: list[str]) -> DataArray:
39+
# locate (inside the graph) the coordinate_system classes from the coordinate_system names
40+
# first validate the input data wrt to axes and input_coordinate_system
41+
# 1. check that the data shape is (n x len(axes))
42+
# 2. check that the axes are the same (and in the same order) of the axes of the input_coordinate_system
43+
# traverse the graph to find the transformation -> Transform class
44+
# apply the transformations to the data (code to get inspired from https://github.com/scverse/spatialdata/blob/6652a03b1d66c8902a8f7a159176c51d8c9f823b/src/spatialdata/transformations/operations.py#L212)
45+
# tranform the data
46+
# return the transformed data as tuple (numpy array, output axes from the output coordinate systme)
47+
48+
Y, X, Z, C = data.shape
49+
yy, xx, zz = np.meshgrid(np.arange(Y), np.arange(X), np.arange(Z), indexing='ij')
50+
points = np.stack([xx, yy, zz], axis=-1).reshape(-1, 3)
51+
52+
# validata_point_shape(points[0], transformation_sequence)
53+
54+
transformed_points = np.array([transformation_sequence.transform_point(p) for p in points])
55+
x_prime = transformed_points[:, 0].reshape(Y, X, Z)
56+
y_prime = transformed_points[:, 1].reshape(Y, X, Z)
57+
z_prime = transformed_points[:, 2].reshape(Y, X, Z)
58+
59+
return xarray.DataArray(data,
60+
coords={
61+
"x_prime": (("y", "x", "z"), x_prime),
62+
"y_prime": (("y", "x", "z"), y_prime),
63+
"z_prime": (("y", "x", "z"), z_prime)
64+
},
65+
dims=output_axes)
66+
67+
68+
def transform_with_sequence(data: np.ndarray, axes: list[str], transformation_sequence: Sequence,
1269
output_axes: list[str]) -> DataArray:
1370
# locate (inside the graph) the coordinate_system classes from the coordinate_system names
1471
# first validate the input data wrt to axes and input_coordinate_system

0 commit comments

Comments
 (0)