|
1 | 1 | import numpy as np |
2 | 2 | from xarray import DataArray |
3 | 3 | import networkx as nx |
4 | | -import ome_zarr_models as ozm |
| 4 | +from ome_zarr_models._v06.coordinate_transforms import CoordinateSystemIdentifier, Sequence |
5 | 5 |
|
6 | | - |
7 | | -def validata_point_shape(point: np.ndarray, transformation_sequence: ozm._v06.coordinate_transforms.Sequence): |
| 6 | +def validata_point_shape(point: np.ndarray, transformation_sequence: Sequence): |
8 | 7 | for transformation in transformation_sequence.transformations: |
9 | 8 | assert len(point) == transformation.ndim, "Point ndim doesn't match transformation ndim" |
10 | 9 |
|
11 | | -def transform_with_sequence(data: np.ndarray, transformation_sequence: ozm._v06.coordinate_transforms.Sequence, |
| 10 | + |
| 11 | +def get_node(path: str | None = None, name: str | None = None) -> str | CoordinateSystemIdentifier: |
| 12 | + if path is None and name is None: |
| 13 | + raise ValueError("Both path and name of the coordinate system cannot be None") |
| 14 | + if path is None: |
| 15 | + return name |
| 16 | + if name is None: |
| 17 | + return path |
| 18 | + return CoordinateSystemIdentifier(path=path, name=name) |
| 19 | + |
| 20 | +def find_walks_in_graph(graph, src_path, src_name, tgt_path, tgt_name): |
| 21 | + src_node = get_node(src_path, src_name) |
| 22 | + tgt_node = get_node(tgt_path, tgt_name) |
| 23 | + |
| 24 | + graph_walk = list(nx.all_shortest_paths(graph, src_node, tgt_node))[0] |
| 25 | + |
| 26 | + transformation_sequence = [] |
| 27 | + for i in range(len(graph_walk) - 1): |
| 28 | + transformation_sequence.append(graph.get_edge_data(graph_walk[i], graph_walk[i + 1])['transformation']) |
| 29 | + |
| 30 | + transformation_sequence = Sequence( |
| 31 | + input=graph_walk[0], |
| 32 | + output=graph_walk[-1], |
| 33 | + transformations=transformation_sequence |
| 34 | + ) |
| 35 | + return transformation_sequence, (graph_walk[0], graph_walk[-1]) |
| 36 | + |
| 37 | +def transform_with_sequence3D(data: np.ndarray, axes: list[str], transformation_sequence: Sequence, |
| 38 | + output_axes: list[str]) -> DataArray: |
| 39 | + # locate (inside the graph) the coordinate_system classes from the coordinate_system names |
| 40 | + # first validate the input data wrt to axes and input_coordinate_system |
| 41 | + # 1. check that the data shape is (n x len(axes)) |
| 42 | + # 2. check that the axes are the same (and in the same order) of the axes of the input_coordinate_system |
| 43 | + # traverse the graph to find the transformation -> Transform class |
| 44 | + # apply the transformations to the data (code to get inspired from https://github.com/scverse/spatialdata/blob/6652a03b1d66c8902a8f7a159176c51d8c9f823b/src/spatialdata/transformations/operations.py#L212) |
| 45 | + # tranform the data |
| 46 | + # return the transformed data as tuple (numpy array, output axes from the output coordinate systme) |
| 47 | + |
| 48 | + Y, X, Z, C = data.shape |
| 49 | + yy, xx, zz = np.meshgrid(np.arange(Y), np.arange(X), np.arange(Z), indexing='ij') |
| 50 | + points = np.stack([xx, yy, zz], axis=-1).reshape(-1, 3) |
| 51 | + |
| 52 | + # validata_point_shape(points[0], transformation_sequence) |
| 53 | + |
| 54 | + transformed_points = np.array([transformation_sequence.transform_point(p) for p in points]) |
| 55 | + x_prime = transformed_points[:, 0].reshape(Y, X, Z) |
| 56 | + y_prime = transformed_points[:, 1].reshape(Y, X, Z) |
| 57 | + z_prime = transformed_points[:, 2].reshape(Y, X, Z) |
| 58 | + |
| 59 | + return xarray.DataArray(data, |
| 60 | + coords={ |
| 61 | + "x_prime": (("y", "x", "z"), x_prime), |
| 62 | + "y_prime": (("y", "x", "z"), y_prime), |
| 63 | + "z_prime": (("y", "x", "z"), z_prime) |
| 64 | + }, |
| 65 | + dims=output_axes) |
| 66 | + |
| 67 | + |
| 68 | +def transform_with_sequence(data: np.ndarray, axes: list[str], transformation_sequence: Sequence, |
12 | 69 | output_axes: list[str]) -> DataArray: |
13 | 70 | # locate (inside the graph) the coordinate_system classes from the coordinate_system names |
14 | 71 | # first validate the input data wrt to axes and input_coordinate_system |
|
0 commit comments