diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f48dda33e..edfa12f9c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -35,6 +35,7 @@ timm>=1.0.3 torch>=2.5.0 torchvision>=0.15.0 tqdm>=4.64.1 +transformers>=4.51.1 umap-learn>=0.5.3 wsidicom>=0.18.0 zarr>=2.13.3, <3.0.0 diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py index 2067ef0a1..12e7dacd0 100644 --- a/tests/engines/test_semantic_segmentor.py +++ b/tests/engines/test_semantic_segmentor.py @@ -157,7 +157,6 @@ def _test_qupath_output_patch(output: Path) -> None: assert "Polygon" in geometry_types # When class_dict is None, types are assigned as 0, 1, ... - assert 0 in class_values assert 1 in class_values # Basic sanity check @@ -1071,7 +1070,6 @@ def _test_store_output_patch(output: Path) -> None: if "type" in probs: annotation_types.add(probs.pop("type")) # When class_dict is none, types are assigned as 0, 1, ... - assert 0 in annotation_types assert 1 in annotation_types assert annotations_properties is not None diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py index 0170ac94e..1a5d02b22 100644 --- a/tests/models/test_arch_grandqc.py +++ b/tests/models/test_arch_grandqc.py @@ -105,7 +105,7 @@ def test_grandqc_with_semantic_segmentor( assert Path(output[sample_image]).exists() store = SQLiteStore.open(output[sample_image]) - assert len(store) == 4 + assert len(store) == 3 unique_types = set() @@ -116,7 +116,7 @@ def test_grandqc_with_semantic_segmentor( tissue_area_px += annotation.geometry.area assert 2950000 < tissue_area_px < 2960000 - assert unique_types == {"background", "tissue"} + assert unique_types == {"tissue"} store.close() diff --git a/tests/models/test_arch_sam.py b/tests/models/test_arch_sam.py new file mode 100644 index 000000000..0241c4aca --- /dev/null +++ b/tests/models/test_arch_sam.py @@ -0,0 +1,61 @@ +"""Unit test package for SAM.""" + +from 
collections.abc import Callable +from pathlib import Path + +import numpy as np +import pytest +import torch + +from tiatoolbox.models.architecture.sam import SAM +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device + +ON_GPU = toolbox_env.has_gpu() + +# Test pretrained Model ============================= + + +def test_functional_sam(remote_sample: Callable) -> None: + """Test for SAM.""" + # convert to pathlib Path to prevent wsireader complaint + tile_path = Path(remote_sample("patch-extraction-vf")) + img = imread(tile_path) + + # test creation + + model = SAM(device=select_device(on_gpu=ON_GPU)) + + # create image patch and prompts + points = np.array([[[64, 64]]]) + boxes = np.array([[[64, 64, 128, 128]]]) + + # test preproc + tensor = torch.from_numpy(img) + patch = np.expand_dims(model.preproc(tensor), axis=0) + patch = model.preproc(patch) + + # test inference + + mask_output, score_output = model.infer_batch( + model, patch, points, device=select_device(on_gpu=ON_GPU) + ) + + assert mask_output is not None, "Output should not be None" + assert len(mask_output) > 0, "Output should have at least one element" + assert len(score_output) > 0, "Output should have at least one element" + + mask_output, score_output = model.infer_batch( + model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU) + ) + + assert len(mask_output) > 0, "Output should have at least one element" + assert len(score_output) > 0, "Output should have at least one element" + + # test error when no prompts provided + with pytest.raises( + ValueError, + match=r"At least one of point_coords or box_coords must be provided.", + ): + _ = model.infer_batch(model, patch, device=select_device(on_gpu=ON_GPU)) diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py index 885a0cd91..b0516fa28 100644 --- a/tests/test_app_bokeh.py +++ b/tests/test_app_bokeh.py @@ -2,6 +2,7 @@ from __future__ import 
annotations +import importlib import importlib.resources as importlib_resources import io import json @@ -9,6 +10,7 @@ import re import shutil import time +import types from pathlib import Path from types import SimpleNamespace from typing import TYPE_CHECKING @@ -21,6 +23,9 @@ from bokeh.application import Application from bokeh.application.handlers import FunctionHandler from bokeh.events import ButtonClick, DoubleTap, MenuItemClick +from bokeh.models import ColorBar +from bokeh.models.tiles import WMTSTileSource +from bokeh.plotting import figure from flask_cors import CORS from matplotlib import colormaps from PIL import Image @@ -32,7 +37,7 @@ from tiatoolbox.visualization.ui_utils import get_level_by_extent if TYPE_CHECKING: # pragma: no cover - from collections.abc import Generator + from collections.abc import Callable, Generator from bokeh.document import Document @@ -616,6 +621,52 @@ def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> No assert len(main.UI["type_column"].children) == 1 +def test_sam_segment(doc: Document, data_path: pytest.TempPathFactory) -> None: + """Test running SAM on a box.""" + slide_select = doc.get_model_by_name("slide_select0") + slide_select.value = [data_path["slide2"].name] + run_button = doc.get_model_by_name("to_model0") + assert len(main.UI["color_column"].children) == 0 + slide_select.value = [data_path["slide1"].name] + # set up a box selection + main.UI["box_source"].data = { + "x": [1200], + "y": [-2000], + "width": [400], + "height": [400], + } + + # select SAM model and run it on box + model_select = doc.get_model_by_name("model_drop0") + model_select.value = "SAM" + + click = ButtonClick(run_button) + run_button._trigger_event(click) + assert len(main.UI["color_column"].children) > 0 + + # test save functionality + save_button = doc.get_model_by_name("save_button0") + click = ButtonClick(save_button) + save_button._trigger_event(click) + saved_path = ( + data_path["base_path"] + / 
"overlays" + / (data_path["slide1"].stem + "_saved_anns.db") + ) + assert saved_path.exists() + + # load an overlay with different types + cprop_select = doc.get_model_by_name("cprop0") + cprop_select.value = ["prob"] + layer_drop = doc.get_model_by_name("layer_drop0") + click = MenuItemClick(layer_drop, str(data_path["dat_anns"])) + layer_drop._trigger_event(click) + assert main.UI["vstate"].types == ["annotation"] + # check the per-type ui controls have been updated + assert len(main.UI["color_column"].children) == 1 + assert len(main.UI["type_column"].children) == 1 + + def test_alpha_sliders(doc: Document) -> None: """Test sliders for adjusting slide and overlay alpha.""" slide_alpha = doc.get_model_by_name("slide_alpha0") @@ -925,6 +976,33 @@ def test_populate_slide_list(doc: Document, data_path: pytest.TempPathFactory) - assert len(slide_select.options) == 4 +def test_clear_overlays(doc: Document, data_path: pytest.TempPathFactory) -> None: + """Test clearing overlays.""" + slide_select = doc.get_model_by_name("slide_select0") + slide_select.value = [data_path["slide1"].name] + + # load an annotation store + layer_drop = doc.get_model_by_name("layer_drop0") + click = MenuItemClick(layer_drop, str(data_path["annotations"])) + layer_drop._trigger_event(click) + assert "overlay" in main.UI["vstate"].layer_dict + + # now clear the overlays + clear_button = doc.get_model_by_name("clear_button0") + click = ButtonClick(clear_button) + clear_button._trigger_event(click) + assert "overlay" not in main.UI["vstate"].layer_dict + assert ( + len(main.UI["vstate"].layer_dict) == 5 + ) # slide & empty box/pt/edge/node renderers + + # click again - should do nothing and not error + click = ButtonClick(clear_button) + clear_button._trigger_event(click) + assert "overlay" not in main.UI["vstate"].layer_dict + assert len(main.UI["vstate"].layer_dict) == 5 + + def test_channel_color_ui_callbacks( doc: Document, data_path: pytest.TempPathFactory, @@ -1047,3 +1125,693 @@ def 
fake_exit() -> None: monkeypatch.setattr(app_hooks, "sys", SimpleNamespace(exit=fake_exit)) app_hooks.on_session_destroyed(_DummySessionContext("user-2")) assert exited + + +def test_dummyattr_stores_value() -> None: + """Ensure that DummyAttr correctly stores the provided value. + + This test verifies that the constructor assigns the input value + to the `item` attribute. + """ + obj = main.DummyAttr("hello") + assert obj.item == "hello" + + +@pytest.mark.parametrize( + "value", + [ + 123, + 3.14, + {"a": 1}, + [1, 2, 3], + (1, 2), + None, + ], +) +def test_dummyattr_accepts_any_type(value: object) -> None: + """Confirm that DummyAttr accepts and stores values of any type.""" + obj = main.DummyAttr(value) + assert obj.item is value + + +class DummyResponse: + """A dummy HTTP response object containing invalid JSON.""" + + text: str = "not valid json" + + +def test_get_channel_info_logs_json_error( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that get_channel_info logs a warning.""" + + class DummySession: + """A dummy session whose GET request returns invalid JSON.""" + + def get(self, url: str) -> DummyResponse: # skipcq: PYL-R0201 # noqa: ARG002 + """Dummy get function.""" + return DummyResponse() + + # Patch __getitem__ on the UIWrapper *class* so UI["s"] returns DummySession() + monkeypatch.setattr( + main.UI.__class__, + "__getitem__", + lambda _self, key: DummySession() if key == "s" else None, + ) + + with caplog.at_level("WARNING"): + result = main.get_channel_info() + + assert result == ({}, []) + + assert any("Error decoding JSON" in message for message in caplog.messages) + + +# --------------------------------------------------------------------------- +# Helper stubs (purely for testing) +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Helper stubs (purely for testing) — ruff/DeepSource 
compliant +# --------------------------------------------------------------------------- + + +class FakeInfo: + """Lightweight stub for WSIReader.info.""" + + def __init__(self, path: Path) -> None: + """Initialize fake slide info with plausible values.""" + self.mpp = [0.25, 0.25] + self.slide_dimensions = (2000, 1500) + self.file_path = path + + def as_dict(self) -> dict[str, object]: + """Return dictionary representation of slide metadata.""" + return { + "file_path": self.file_path, + "mpp": self.mpp, + "dims": self.slide_dimensions, + } + + +class FakeWSIReader: + """Stub mimicking tiatoolbox WSIReader.""" + + def __init__(self, path: Path) -> None: + """Create a fake reader with associated FakeInfo.""" + self.info = FakeInfo(path) + + @staticmethod + def open(path: Path | str) -> FakeWSIReader: + """Replacement for WSIReader.open that returns a controlled stub.""" + return FakeWSIReader(Path(path)) + + @staticmethod + def slide_thumbnail() -> np.ndarray: + """Return a small fake thumbnail array.""" + return 255 * np.ones((20, 20, 3), dtype="uint8") + + @staticmethod + def read_bounds(*_args: object, **_kwargs: object) -> np.ndarray: + """Return a fake array for ROI extraction.""" + return 255 * np.ones((10, 10, 3), dtype="uint8") + + +class FakeZoomifyGenerator: + """Stub mimicking ZoomifyGenerator with a fixed zoom level count.""" + + def __init__(self, *_args: object, **_kwargs: object) -> None: + """Initialize with a default level count.""" + self.level_count = 7 + + +class FakeResp: + """Lightweight fake requests.Response used for stubbing.""" + + class CookieDict(dict): + """Subclass dict to mimic requests cookie access.""" + + def get(self, key: str, default: object | None = None) -> object: + """Return cookie value or default.""" + return super().get(key, default) + + def __init__(self, text: str = "", cookies: dict[str, str] | None = None) -> None: + """Initialize fake response with text and optional cookies.""" + self.text = text + self._cookies = 
FakeResp.CookieDict(cookies or {}) + + @property + def cookies(self) -> dict[str, str]: + """Return fake cookie jar.""" + return self._cookies + + +class FakeSession: + """Minimal stub for requests.Session.""" + + def __init__(self: FakeSession) -> None: + """Initialize fake session with no proxies and empty mounts.""" + self.trust_env = False + self.proxies: dict[str, str | None] = {} + self.mounted: dict[str, object] = {} + + def mount(self: FakeSession, scheme: str, adapter: object) -> None: + """Record mounted adapters (no real effect).""" + self.mounted[scheme] = adapter + + @staticmethod + def get(url: str, *_args: object, **_kwargs: object) -> FakeResp: + """Return canned JSON or session cookie.""" + if url.endswith("/tileserver/session_id"): + return FakeResp(cookies={"session_id": "test_user"}) + return FakeResp(text="{}") + + @staticmethod + def put( + url: str, + *_args: object, + _data: dict[str, object] | None = None, + **_kwargs: object, + ) -> FakeResp: + """Return controlled values based on the endpoint.""" + if url.endswith("/tileserver/overlay"): + return FakeResp(text='"slide"') + return FakeResp(text='"ok"') + + @staticmethod + def post( + _url: str, + *_args: object, + _data: dict[str, object] | None = None, + **_kwargs: object, + ) -> FakeResp: + """Return generic OK response.""" + return FakeResp(text='"ok"') + + +class FakeRequest: + """Stub for request object inside session context.""" + + def __init__(self, arguments: dict[str, list[bytes]] | None = None) -> None: + """Initialize with an optional arguments dict.""" + self.arguments = arguments or {} + + +class FakeSessionContext: + """Stub for Bokeh session context.""" + + def __init__(self, arguments: dict[str, list[bytes]] | None = None) -> None: + """Create a session context holding a FakeRequest.""" + self.request = FakeRequest(arguments=arguments) + + +class FakeDoc: + """Stub for Bokeh Document used during import/setup.""" + + def __init__( + self, + *, + with_session: bool = False, 
+ arguments: dict[str, list[bytes]] | None = None, + ) -> None: + """Initialize a fake Bokeh document.""" + self._sc = FakeSessionContext(arguments) if with_session else None + self.template_variables: dict[str, object] = {} + self._callbacks: list[tuple[object, int]] = [] + self._roots: list[object] = [] + self.title: str = "" + + @property + def session_context(self) -> FakeSessionContext | None: + """Return fake session context if enabled.""" + return self._sc + + def add_periodic_callback(self, fn: object, ms: int) -> None: + """Record callbacks for later introspection.""" + self._callbacks.append((fn, ms)) + + def add_root(self, root: object) -> None: + """Record Bokeh layout roots.""" + self._roots.append(root) + + +# --------------------------------------------------------------------------- +# Controlled import of main.py under patching +# --------------------------------------------------------------------------- + + +def reload_main( + monkeypatch: pytest.MonkeyPatch, + *, + with_session: bool = False, + req_args: dict[str, list[bytes]] | None = None, + pre_import_patch: Callable | None = None, +) -> object: + """Reload ``main`` in a controlled environment.""" + # Apply pre-import patches FIRST (critical). 
+ if pre_import_patch is not None: + pre_import_patch(monkeypatch) + + # Patch curdoc + monkeypatch.setattr( + "bokeh.io.curdoc", + lambda: FakeDoc(with_session=with_session, arguments=req_args), + raising=False, + ) + + # Core stub patches used everywhere in the suite + monkeypatch.setattr("requests.Session", FakeSession, raising=True) + monkeypatch.setattr( + "tiatoolbox.wsicore.wsireader.WSIReader.open", + FakeWSIReader.open, + raising=True, + ) + monkeypatch.setattr( + "tiatoolbox.tools.pyramid.ZoomifyGenerator", + FakeZoomifyGenerator, + raising=True, + ) + + # Safe reload with all patches applied + return importlib.reload(main) + + +# --------------------------------------------------------------------------- +# Actual tests +# --------------------------------------------------------------------------- + + +def test_populate_table_active_channels_false( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test populate_table() branch where active_channels == []. + + Ensures: + - Table rows populate from colors. + - No indices selected. + + """ + main = reload_main(monkeypatch, with_session=False) + + channel_ui = main.create_channel_color_ui() + main.win_dicts.clear() + main.win_dicts.append({"channel_select": channel_ui}) + main.UI.active = 0 + + monkeypatch.setattr( + main, + "get_channel_info", + lambda: ({"C1": (1, 2, 3), "C2": (4, 5, 6)}, []), + raising=True, + ) + + main.populate_table() + tables = main.UI["channel_select"].children[1].children[0].children + + assert tables[0].source.data["channels"] == ["C1", "C2"] + assert len(tables[1].source.data["colors"]) == 2 + assert tables[0].source.selected.indices == [] + + +def test_initialise_slide_initial_view_limits( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test initialise_slide() branch where slide_name appears in initial_views. + + Ensures p.x_range/y_range are set from config. 
+ + """ + main = reload_main(monkeypatch, with_session=False) + + slide_path = Path("/tmp/s1.svs") # noqa: S108 + vstate = main.ViewerState(slide_path) + + p = figure(width=400, height=400) + + main.win_dicts.clear() + main.win_dicts.append({"vstate": vstate, "p": p}) + main.UI.active = 0 + + main.doc_config.config["initial_views"] = {"s1": [10, 20, 110, 120]} + main.initialise_slide() + + assert p.x_range.start == 10 + assert p.x_range.end == 110 + assert p.y_range.start == -120 + assert p.y_range.end == -20 + + +def test_slide_select_cb_auto_load_triggers_layer_drop( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test slide_select_cb() when auto_load=True. + + Ensures: + - slide_select_cb() runs without error + - auto_load triggers layer_drop_cb() exactly once + """ + main = reload_main(monkeypatch, with_session=False) + + class MiniHover: + """Initialize MiniHover.""" + + tooltips: object | None = None + + p = figure(width=400, height=300) + + # renderer[0] must be a TileRenderer for change_tiles() + dummy_ts = WMTSTileSource(url="http://dummy/{Z}/{X}/{Y}.png") + p.add_tile(dummy_ts) + + # Fill remaining permanent renderers + for _ in range(main.N_PERMANENT_RENDERERS - 1): + p.circle([], []) + + class FakeTable: + """Create a minimal fake channel_select matching expected structure.""" + + def __init__(self: FakeTable) -> None: + self.source = main.ColumnDataSource({"channels": [], "dummy": []}) + self.selected = types.SimpleNamespace(indices=[]) + + class FakeChannelSelect: + """Must satisfy: UI["channel_select"].children[1].children[0].children.""" + + def __init__(self: FakeChannelSelect) -> None: + table1 = FakeTable() # channels table + table2 = FakeTable() # colors table + + inner = types.SimpleNamespace(children=[table1, table2]) + outer = types.SimpleNamespace(children=[inner]) + + self.children = [None, outer] + + class DummyCol: + """Initialize DummyCol test.""" + + def __init__(self: DummyCol) -> None: + self.children: list[object] = [] + + win = 
{ + "p": p, + "vstate": main.ViewerState(Path("/tmp/initial.svs")), # noqa: S108 + "pt_source": main.ColumnDataSource({"x": [], "y": []}), + "box_source": main.ColumnDataSource( + {"x": [], "y": [], "width": [], "height": []} + ), + "node_source": main.ColumnDataSource({"x_": [], "y_": [], "node_color_": []}), + "edge_source": main.ColumnDataSource( + {"x0_": [], "y0_": [], "x1_": [], "y1_": []} + ), + "hover": MiniHover(), + "layer_drop": main.Dropdown( + label="Add Overlay", menu=[("foo.json", "foo.json")] + ), + "s": FakeSession(), + # Required UI structures + "color_column": DummyCol(), + "type_column": DummyCol(), + "channel_select": FakeChannelSelect(), + # Required by change_tiles() + "user": "test_user", + } + + main.win_dicts.clear() + main.win_dicts.append(win) + main.UI.active = 0 + + # Enable auto-load + main.doc_config.config["auto_load"] = True + + # Avoid filesystem operations + monkeypatch.setattr( + main, + "populate_layer_list", + lambda *_args, **_kwargs: None, + raising=True, + ) + + # Spy on layer_drop_cb + called: list[str] = [] + + def spy(attr: object) -> None: + """Spy attribute for layer_drop_cb.""" + called.append(attr.item) + + monkeypatch.setattr(main, "layer_drop_cb", spy, raising=True) + + # ACT + new_slide_name = [Path("/tmp/newslide.svs").name] # noqa: S108 + main.slide_select_cb(None, None, new=new_slide_name) + + # ASSERT + assert called == ["foo.json"] + + +def test_layer_drop_cb_resp_equals_slide( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test branch in layer_drop_cb() where resp == "slide". 
+ + Expected behavior: + - add_layer() is NOT called + - change_tiles("slide") IS called (this is correct behaviour) + + """ + main = reload_main(monkeypatch, with_session=False) + + # Prepare minimal window with a FakeSession + main.win_dicts.clear() + main.win_dicts.append({"s": FakeSession()}) + main.UI.active = 0 + + # Spy: add_layer() must not be called + monkeypatch.setattr( + main, + "add_layer", + lambda *_a, **_k: (_ for _ in ()).throw( + AssertionError("add_layer should NOT be called") + ), + raising=True, + ) + + # Spy: change_tiles() SHOULD be called exactly once + change_calls: list[str] = [] + + def change_spy(arg: str) -> None: + """Change spy.""" + change_calls.append(arg) + + monkeypatch.setattr(main, "change_tiles", change_spy, raising=True) + + # Trigger overlay load on a non-annotation file => uses resp == "slide" + attr = main.DummyAttr("/tmp/overlay.png") # noqa: S108 + main.layer_drop_cb(attr) + + # Assert correct behavior + assert change_calls == ["slide"], ( + "change_tiles('slide') must be called when resp == 'slide'" + ) + + +def test_gather_ui_elements_ui_elements_1_and_2_true( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test gather_ui_elements() when config includes ui_elements_1 and ui_elements_2. + + Ensures only enabled elements are included. 
+ """ + main = reload_main(monkeypatch, with_session=False) + + main.doc_config.config["ui_elements_1"] = { + "slide_select": 1, + "layer_drop": 1, + "slide_row": 1, + "overlay_row": 1, + "filter_input": 0, + "cprop_input": 1, + "cmap_row": 1, + "type_cmap_select": 1, + "model_row": 1, + "clear_button": 1, + "type_select_row": 1, + } + + main.doc_config.config["ui_elements_2"] = { + "opt_buttons": 1, + "pt_size_spinner": 1, + "edge_size_spinner": 0, + "res_switch": 1, + "channel_select": 1, + } + + vstate = main.ViewerState(Path("/tmp/v.svs")) # noqa: S108 + ui_layout, extra_options, _elements = main.gather_ui_elements(vstate, win_num=0) + + present = {c.name for c in ui_layout.children if hasattr(c, "name")} + assert "filter0" not in present + assert "slide_select0" in present + assert "cprop0" in present + + assert len(extra_options.children) == 4 + + +def test_make_window_sets_user_when_session_context_true( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test branch where make_window() sees session_context present. 
+ + Ensures: + - 'user' cookie gets written to session_context.request.arguments + + """ + # Import main WITHOUT a session context so module-level do_doc is False + main = reload_main(monkeypatch, with_session=False) + + # Prepare a FakeDoc WITH a session_context + # this is the one we want make_window to see + req_args: dict[str, list[bytes]] = {} + fake_doc = FakeDoc(with_session=True, arguments=req_args) + + # CRITICAL FIX: patch main.curdoc, not bokeh.io.curdoc + monkeypatch.setattr(main, "curdoc", lambda: fake_doc, raising=False) + + # Create a vstate and call make_window + vstate = main.ViewerState(Path("/tmp/one.svs")) # noqa: S108 + main.make_window(vstate) + + # Assertions: make_window must write "user" + sc = main.curdoc().session_context + assert sc is not None + assert "user" in sc.request.arguments + assert sc.request.arguments["user"] == "test_user" + + +def test_make_window_hover_nodes_edges_colorbar( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test make_window() branches. 
+ + - opts.hover_on == 1 (skip disabling inspect tool) + - nodes_on == False + - edges_on == False + - colorbar_on == 0 + + """ + main = reload_main(monkeypatch, with_session=False) + + main.doc_config.config["opts"] = { + "hover_on": 1, + "nodes_on": False, + "edges_on": False, + "colorbar_on": 0, + } + + vstate = main.ViewerState(Path("/tmp/abc.svs")) # noqa: S108 + win = main.make_window(vstate) + p = win["p"] + + assert p.toolbar.active_inspect is not None + + node_renderer = p.renderers[-2] + assert node_renderer.glyph.fill_alpha == 0 + assert node_renderer.glyph.line_alpha == 0 + + edge_renderer = p.renderers[-1] + assert edge_renderer.visible is False + + assert not p.select(type=ColorBar) + + +def test_make_window_edges_on_true(monkeypatch: pytest.MonkeyPatch) -> None: + """Import main with NO session so do_doc=False.""" + main = reload_main(monkeypatch, with_session=False) + + # Force config branch + main.doc_config.config["opts"] = {"edges_on": True} + + v = main.ViewerState(Path("/tmp/a.svs")) # noqa: S108 + win = main.make_window(v) + + # When edges_on=True, renderer must remain visible (i.e., NOT hidden) + p = win["p"] + edge_renderer = p.renderers[win["vstate"].layer_dict["edges"]] + + assert edge_renderer.visible is True + + +def test_docconfig_no_config_file( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test _get_config when no *config.json file is present.""" + # Avoid module-level auto-setup + main = reload_main(monkeypatch, with_session=False) + + # Provide only base_folder; overlays contain no *config.json + slides_dir = tmp_path / "slides" + overlays_dir = tmp_path / "overlays" + slides_dir.mkdir() + overlays_dir.mkdir() + + main.req_args = {} # no extra query args + dc = main.doc_config + dc.set_sys_args(["prog", str(tmp_path)]) + + # Execute and ensure fallback path doesn't crash and sets initial_views + dc._get_config() + assert "initial_views" in dc.config + assert dc.config["initial_views"] == {} + + +def 
test_docconfig_get_config_basefolder_and_demo( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test _get_config handles base_folder with demo query parameter.""" + # Avoid module-level auto-setup + main = reload_main(monkeypatch, with_session=False) + + # Simulate argv with only base folder (len==2) and a demo in req_args + dc = main.doc_config + dc.set_sys_args(["prog", str(tmp_path)]) + main.req_args = {"demo": [b"DemoX"]} + + # Invoke _get_config + dc._get_config() + + # Expect updated slide/overlay/base folders under DemoX + assert dc.config["demo_name"] == "DemoX" + assert dc.config["slide_folder"].name == "slides" + assert dc.config["overlay_folder"].name == "overlays" + assert dc.config["slide_folder"].parent.name == "DemoX" + assert dc.config["overlay_folder"].parent.name == "DemoX" + + +def test_docconfig_request_args_slide_and_window( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test _get_config processes ?slide= and ?window= request args.""" + # Avoid module-level auto-setup + main = reload_main(monkeypatch, with_session=False) + + # Create folders to satisfy path arithmetic in _get_config + slides_dir = tmp_path / "slides" + overlays_dir = tmp_path / "overlays" + slides_dir.mkdir() + overlays_dir.mkdir() + + # Provide fake req_args that main._get_config reads + main.req_args = { + "slide": [b"S1.svs"], + "window": [b"[10,20,100,200]"], + } + + dc = main.doc_config + dc.set_sys_args(["prog", str(tmp_path)]) + dc._get_config() + + assert dc.config["first_slide"] == "S1.svs" + assert dc.config["initial_views"]["S1"] == [10, 20, 100, 200] diff --git a/tests/test_server_bokeh.py b/tests/test_server_bokeh.py index a56be1774..b87b9c9a4 100644 --- a/tests/test_server_bokeh.py +++ b/tests/test_server_bokeh.py @@ -88,13 +88,14 @@ def test_slides_available(bk_session: ClientSession) -> None: """Test that the slides and overlays are available.""" doc = bk_session.document slide_select = 
doc.get_model_by_name("slide_select0") - # check there are two available slides assert len(slide_select.options) == 2 # check that the overlays are available. - slide_select.value = ["CMU-1-Small-region.svs"] + slide_select.value = ["CMU-1.ndpi"] layer_drop = doc.get_model_by_name("layer_drop0") assert len(layer_drop.menu) == 2 + slide_select.value = ["CMU-1-Small-region.svs"] + assert len(layer_drop.menu) == 2 bk_session.document.clear() assert len(bk_session.document.roots) == 0 diff --git a/tests/test_tileserver.py b/tests/test_tileserver.py index d987027da..474dce26a 100644 --- a/tests/test_tileserver.py +++ b/tests/test_tileserver.py @@ -367,6 +367,16 @@ def test_load_save_annotations(app: TileServer, track_tmp_path: Path) -> None: assert len(store) == num_annotations + 2 +def test_clear_overlays(app: TileServer) -> None: + """Test clearing overlays.""" + with app.test_client() as client: + response = client.put("/tileserver/clear_overlays") + assert response.status_code == 200 + assert response.content_type == "text/html; charset=utf-8" + # check that the overlay has been correctly cleared + assert "overlay" not in app.pyramids["default"] + + def test_load_annotations_empty( empty_app: TileServer, track_tmp_path: Path, diff --git a/tests/test_tileserver_channels.py b/tests/test_tileserver_channels.py index a8bd6640b..127f0d008 100644 --- a/tests/test_tileserver_channels.py +++ b/tests/test_tileserver_channels.py @@ -41,17 +41,21 @@ def __init__( ) -> None: self.info = info or _FakeInfo() self.post_proc = post_proc - self._thumb_called = 0 + self._called = 0 - def slide_thumbnail( - self, *, resolution: float = 8.0, units: str = "mpp" + def read_rect( + self, + location: tuple[float, float] = (0, 0), + size: tuple[float, float] = (100, 100), + resolution: int = 0, + units: str = "level", ) -> np.ndarray: - """Fake thumbnail method that counts calls.""" - _ = (resolution, units) # mark as used to satisfy Ruff + """Fake read_rect method that counts calls.""" + _ 
= (resolution, units, location) # mark as used to satisfy Ruff # Parameters are part of the real Slide API; unused by the fake. - self._thumb_called += 1 + self._called += 1 # returning an array avoids PIL issues if used elsewhere - return np.zeros((8, 8, 3), dtype=np.uint8) + return np.zeros((size[0], size[1], 3), dtype=np.uint8) @pytest.fixture @@ -81,7 +85,7 @@ def session_id(client: TileServer) -> str: return sid -def test_get_channels_populated_and_triggers_thumbnail_when_not_validated( +def test_get_channels_populated_and_triggers_read_when_not_validated( app: TileServer, client: TileServer, session_id: str ) -> None: """Covers lines in get_channels by MultichannelToRGB and is_validated gate.""" @@ -106,14 +110,14 @@ def test_get_channels_populated_and_triggers_thumbnail_when_not_validated( # JSON serializes tuples to lists assert payload["channels"] == {"c0": [1.0, 0.0, 0.0], "c1": [0.0, 1.0, 0.0]} assert payload["active"] == ["c0", "c1"] - # Should have forced a thumbnail when not validated - assert slide._thumb_called == 1 + # Should have forced a read_rect when not validated + assert slide._called == 1 - # Call again with already validated state to ensure no extra thumbnailing + # Call again with already validated state to ensure no extra read_rect calls app.layers[sid]["slide"].post_proc.is_validated = True r2 = client.get("/tileserver/channels") assert r2.status_code == 200 - assert slide._thumb_called == 1 # unchanged + assert slide._called == 1 # unchanged def test_set_channels_updates_dicts_and_marks_unvalidated( diff --git a/tests/test_utils.py b/tests/test_utils.py index 0cb64206f..df60c111a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1929,8 +1929,9 @@ def test_dict_to_store_semantic_segment() -> None: class_dict=None, save_path=None, output_type="annotationstore", + ignore_index=0, ) - assert len(store_) == 3 + assert len(store_) == 2 annotations_ = store_.values() diff --git a/tiatoolbox/data/pretrained_model.yaml 
b/tiatoolbox/data/pretrained_model.yaml index 4f38af495..c8aa06d88 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -589,6 +589,7 @@ fcn-tissue_mask: patch_output_shape: [512, 512] stride_shape: [450, 450] save_resolution: {'units': 'mpp', 'resolution': 8.0} + ignore_index: 0 fcn_resnet50_unet-bcss: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -610,6 +611,7 @@ fcn_resnet50_unet-bcss: patch_output_shape: [512, 512] stride_shape: [450, 450] save_resolution: {'units': 'mpp', 'resolution': 0.25} + ignore_index: 0 unet_tissue_mask_tsef: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -635,6 +637,7 @@ unet_tissue_mask_tsef: patch_output_shape: [512, 512] stride_shape: [256, 256] save_resolution: {'units': 'baseline', 'resolution': 1.0} + ignore_index: 0 hovernet_fast-pannuke: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -666,6 +669,7 @@ hovernet_fast-pannuke: patch_output_shape: [164, 164] stride_shape: [164, 164] save_resolution: {'units': 'mpp', 'resolution': 0.25} + ignore_index: 0 hovernet_fast-monusac: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -696,6 +700,7 @@ hovernet_fast-monusac: patch_output_shape: [164, 164] stride_shape: [164, 164] save_resolution: {'units': 'mpp', 'resolution': 0.25} + ignore_index: 0 hovernet_original-consep: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -748,6 +753,7 @@ hovernet_original-kumar: patch_output_shape: [80, 80] stride_shape: [80, 80] save_resolution: {'units': 'mpp', 'resolution': 0.25} + ignore_index: 0 hovernetplus-oed: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -784,6 +790,7 @@ hovernetplus-oed: patch_output_shape: [164, 164] stride_shape: [164, 164] save_resolution: {'units': 'mpp', 'resolution': 0.50} + ignore_index: 0 micronet-consep: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -804,6 +811,7 @@ micronet-consep: patch_output_shape: [252, 252] stride_shape: [150, 150] save_resolution: {'units': 'mpp', 
'resolution': 0.25} + ignore_index: 0 mapde-crchisto: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -925,6 +933,7 @@ nuclick_original-pannuke: patch_input_shape: [128, 128] patch_output_shape: [128, 128] save_resolution: {'units': 'baseline', 'resolution': 1.0} + ignore_index: 0 nuclick_light-pannuke: hf_repo_id: TIACentre/TIAToolbox_pretrained_weights @@ -947,6 +956,7 @@ nuclick_light-pannuke: patch_input_shape: [128, 128] patch_output_shape: [128, 128] save_resolution: {'units': 'baseline', 'resolution': 1.0} + ignore_index: 0 grandqc_tissue_detection: hf_repo_id: TIACentre/GrandQC_Tissue_Detection @@ -954,6 +964,10 @@ grandqc_tissue_detection: class: grandqc.GrandQCModel kwargs: num_output_channels: 2 + class_dict: { + 0: "Background", + 1: "Tissue", + } ioconfig: class: io_config.IOSegmentorConfig kwargs: @@ -965,6 +979,7 @@ grandqc_tissue_detection: patch_output_shape: [512, 512] stride_shape: [256, 256] save_resolution: {'units': 'mpp', 'resolution': 10.0} + ignore_index: 0 KongNet_CoNIC_1: hf_repo_id: TIACentre/KongNet_pretrained_weights diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index caa4b5ebe..0885c99ad 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -9,6 +9,7 @@ from .architecture.mapde import MapDe from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick +from .architecture.sam import SAM from .architecture.sccnn import SCCNN from .dataset import PatchDataset, WSIPatchDataset from .engine.deep_feature_extractor import DeepFeatureExtractor @@ -22,9 +23,11 @@ from .engine.nucleus_detector import NucleusDetector from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor from .engine.patch_predictor import PatchPredictor +from .engine.prompt_segmentor import PromptSegmentor from .engine.semantic_segmentor import SemanticSegmentor __all__ = [ + "SAM", "SCCNN", "DeepFeatureExtractor", "HoVerNet", @@ -42,6 +45,7 @@ "NucleusInstanceSegmentor", 
"PatchDataset", "PatchPredictor", + "PromptSegmentor", "SemanticSegmentor", "WSIPatchDataset", "architecture", diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index 2e3ce227d..0e02af121 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -470,7 +470,9 @@ class GrandQCModel(ModelABC): """ - def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: + def __init__( + self: GrandQCModel, num_output_channels: int = 2, class_dict: dict | None = None + ) -> None: """Initialize GrandQCModel. Sets up the UNet++ decoder, EfficientNet encoder, and segmentation head @@ -479,6 +481,8 @@ def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: Args: num_output_channels (int): Number of output classes. Defaults to 2 (Tissue and Background). + class_dict (dict | None): + Optional dictionary mapping class names to indices. Defaults to None. """ super().__init__() @@ -505,6 +509,7 @@ def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: ) self.name = "unetplusplus-efficientnetb0" + self.class_dict = class_dict def forward( # skipcq: PYL-W0613 self: GrandQCModel, @@ -576,7 +581,7 @@ def postproc(image: np.ndarray) -> np.ndarray: Returns: np.ndarray: - Binary tissue mask where 0 = Tissue and 1 = Background. + Binary tissue mask where 1 = Tissue and 0 = Background. Example: >>> probs = np.random.rand(256, 256, 2) diff --git a/tiatoolbox/models/architecture/sam.py b/tiatoolbox/models/architecture/sam.py new file mode 100644 index 000000000..98198d757 --- /dev/null +++ b/tiatoolbox/models/architecture/sam.py @@ -0,0 +1,235 @@ +"""Define SAM architecture.""" + +from __future__ import annotations + +import numpy as np +import torch +from PIL import Image +from transformers import SamModel, SamProcessor + +from tiatoolbox.models.models_abc import ModelABC + + +class SAM(ModelABC): + """Segment Anything Model (SAM) Architecture. 
+ + Meta AI's zero-shot segmentation model. + SAM is used for interactive general-purpose segmentation. + + Currently supports SAM. + + SAM accepts an RGB image patch along with a list of point and bounding + box coordinates as prompts. + + Args: + model_path (str): + Path to the model (huggingface). + device (str): + Device to run inference on. + + Examples: + >>> # instantiate SAM with checkpoint path and model type + >>> sam = SAM( + ... model_path="facebook/sam-vit-b", + ... device="cuda", + ... ) + """ + + def __init__( + self: SAM, + model_path: str = "facebook/sam-vit-huge", + *, + device: str = "cpu", + ) -> None: + """Initialize :class:`SAM`.""" + super().__init__() + self.net_name = "SAM" + self.device = device + + self.model = SamModel.from_pretrained(model_path).to(device) + self.processor = SamProcessor.from_pretrained(model_path) + + def _process_prompts( + self: SAM, + image: np.ndarray, + embeddings: torch.Tensor, + orig_sizes: torch.Tensor, + reshaped_sizes: torch.Tensor, + points: list | None = None, + boxes: list | None = None, + point_labels: list | None = None, + ) -> tuple[list, list]: + """Process prompts and return masks and scores.""" + inputs = self.processor( + image, + input_points=points, + input_labels=point_labels, + input_boxes=boxes, + return_tensors="pt", + ).to(self.device) + + # Replaces pixel_values with image embeddings + inputs.pop("pixel_values", None) + inputs.update( + { + "image_embeddings": embeddings, + "original_sizes": orig_sizes, + "reshaped_input_sizes": reshaped_sizes, + } + ) + + with torch.inference_mode(): + # Forward pass through the model + outputs = self.model(**inputs, multimask_output=False) + image_masks = self.processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu(), + ) + image_scores = outputs.iou_scores.cpu() + + return image_masks, image_scores + + def forward( # skipcq: PYL-W0221 + self: SAM, + imgs: list, + 
point_coords: list | None = None, + box_coords: list | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """PyTorch method. Defines forward pass on each image in the batch. + + Note: This architecture only uses a single layer, so only one forward pass + is needed. + + Args: + imgs (list): + List of images to process, of the shape NHWC. + point_coords (list): + List of point coordinates for each image. + box_coords (list): + Bounding box coordinates for each image. + + Returns: + tuple[np.ndarray, np.ndarray]: + Array of masks and scores for each image. + + """ + masks, scores = [], [] + + for i, img in enumerate(imgs): + image = [Image.fromarray(img)] + embeddings, orig_sizes, reshaped_sizes = self._encode_image(image) + point_labels = None + points = None + boxes = None + + if box_coords is not None: + boxes = box_coords[i] + # Convert box coordinates to list + boxes = [boxes[:, None, :].tolist()] + image_masks, image_scores = self._process_prompts( + image, + embeddings, + orig_sizes, + reshaped_sizes, + None, + boxes, + point_labels, + ) + masks.append(np.array([image_masks])) + scores.append(np.array([image_scores])) + + if point_coords is not None: + points = point_coords[i] + # Convert point coordinates to list + point_labels = np.ones((1, len(points), 1), dtype=int).tolist() + points = [points[:, None, :].tolist()] + image_masks, image_scores = self._process_prompts( + image, + embeddings, + orig_sizes, + reshaped_sizes, + points, + None, + point_labels, + ) + masks.append(np.array([image_masks])) + scores.append(np.array([image_scores])) + + torch.cuda.empty_cache() + + return np.concatenate(masks, axis=2), np.concatenate(scores, axis=2) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: list, + point_coords: np.ndarray | None = None, + box_coords: np.ndarray | None = None, + *, + device: str = "cpu", + ) -> tuple[np.ndarray, np.ndarray]: + """Run inference on an input batch. 
+ + Contains logic for forward operation as well as I/O aggregation. + SAM accepts a list of points and a single bounding box per image. + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (list): + A batch of data generated by + `torch.utils.data.DataLoader`. + point_coords (np.ndarray | None): + Point coordinates for each image in the batch. + box_coords (np.ndarray | None): + Bounding box coordinates for each image in the batch. + device (str): + Device to run inference on. + + Returns: + pred_info (tuple[np.ndarray, np.ndarray]): + Tuple of masks and scores for each image in the batch. + + """ + model.eval().to(device) + if point_coords is None and box_coords is None: + msg = "At least one of point_coords or box_coords must be provided." + raise ValueError(msg) + + with torch.inference_mode(): + masks, scores = model(batch_data, point_coords, box_coords) + + return masks, scores + + def _encode_image(self: SAM, image: np.ndarray) -> np.ndarray: + """Encodes image and stores size info for later mask post-processing.""" + processed = self.processor(image, return_tensors="pt") + original_sizes = processed["original_sizes"] + reshaped_sizes = processed["reshaped_input_sizes"] + + inputs = processed.to(self.device) + embeddings = self.model.get_image_embeddings(inputs["pixel_values"]) + return embeddings, original_sizes, reshaped_sizes + + @staticmethod + def preproc(image: np.ndarray) -> np.ndarray: + """Pre-processes an image - Converts it into a format accepted by SAM (HWC).""" + # Move the tensor to the CPU if it's a PyTorch tensor + if isinstance(image, torch.Tensor): + image = image.permute(1, 2, 0).cpu().numpy() + + return image[..., :3] # Remove alpha channel if present + + def to( + self: ModelABC, + device: str = "cpu", + dtype: torch.dtype | None = None, + *, + non_blocking: bool = False, + ) -> ModelABC | torch.nn.DataParallel[ModelABC]: + """Moves the model to the specified device.""" + super().to(device, dtype=dtype, 
non_blocking=non_blocking) + self.device = device + self.model.to(device) + return self diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index 8c3d49473..b3af1e557 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -62,6 +62,7 @@ class ModelIOConfigABC: patch_input_shape: list[int] | np.ndarray | tuple[int, int] stride_shape: list[int] | np.ndarray | tuple[int, int] = None output_resolutions: list[dict] = field(default_factory=list) + ignore_index: int | None = None def __post_init__(self: ModelIOConfigABC) -> None: """Perform post initialization tasks.""" diff --git a/tiatoolbox/models/engine/prompt_segmentor.py b/tiatoolbox/models/engine/prompt_segmentor.py new file mode 100644 index 000000000..608251c7c --- /dev/null +++ b/tiatoolbox/models/engine/prompt_segmentor.py @@ -0,0 +1,115 @@ +"""This module enables interactive segmentation.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +from tiatoolbox.models.architecture.sam import SAM +from tiatoolbox.utils.misc import dict_to_store_semantic_segmentor + +if TYPE_CHECKING: # pragma: no cover + import torch + + from tiatoolbox.type_hints import IntPair + + +class PromptSegmentor: + """Engine for prompt-based segmentation of WSIs. + + This class is designed to work with the SAM model architecture. + It allows for interactive segmentation by providing point and bounding box + coordinates as prompts. The model is intended to be used with image tiles + selected interactively in some way and provided as np.arrays. At least + one of either point_coords or box_coords must be provided to guide + segmentation. + + Args: + model (SAM): + Model architecture to use. If None, defaults to SAM.
+ + """ + + def __init__( + self, + model: torch.nn.Module = None, + ) -> None: + """Initializes the PromptSegmentor.""" + model = SAM() if model is None else model + self.model = model + self.scale = 1.0 + self.offset = np.array([0, 0]) + + def run( # skipcq: PYL-W0221 + self, + images: list, + point_coords: np.ndarray | None = None, + box_coords: np.ndarray | None = None, + save_dir: str | Path | None = None, + device: str = "cpu", + ) -> list[Path]: + """Run inference on image patches with prompts. + + Args: + images (list): + List of image patch arrays to run inference on. + point_coords (np.ndarray): + N_im x N_points x 2 array of point coordinates for each image patch. + box_coords (np.ndarray): + N_im x N_boxes x 4 array of bounding box coordinates for each + image patch. + save_dir (str or Path): + Directory to save the output databases. + device (str): + Device to run inference on. + + Returns: + list[Path]: + Paths to the saved output databases. + + """ + paths = [] + masks, _ = self.model.infer_batch( + self.model, + images, + point_coords=point_coords, + box_coords=box_coords, + device=device, + ) + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + for i, _mask in enumerate(masks): + mask = np.any(_mask[0], axis=0, keepdims=False) + dict_to_store_semantic_segmentor( + patch_output={"predictions": mask[0]}, + scale_factor=(self.scale, self.scale), + offset=self.offset, + save_path=Path(f"{save_dir}/{i}.db"), + output_type="annotationstore", + ignore_index=0, + ) + paths.append(Path(f"{save_dir}/{i}.db")) + return paths + + def calc_mpp( + self, area_dims: IntPair, base_mpp: float, fixed_size: int = 1500 + ) -> tuple[float, float]: + """Calculates the microns per pixel for a fixed area of an image. + + Args: + area_dims (tuple): + Dimensions of the area to be scaled. + base_mpp (float): + Microns per pixel of the base image. + fixed_size (int): + Fixed size of the area. 
+ + Returns: + tuple[float, float]: + Tuple of the scaled mpp and the scale factor. + """ + scale = max(area_dims) / fixed_size if max(area_dims) > fixed_size else 1.0 + self.scale = scale + return base_mpp * scale, scale diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 19fae3d69..c7748b597 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -833,6 +833,7 @@ def save_predictions( output_type=output_type, class_dict=class_dict, save_path=output_path, + ignore_index=self._ioconfig.ignore_index, verbose=self.verbose, ) @@ -844,6 +845,7 @@ def save_predictions( output_type=output_type, class_dict=class_dict, save_path=save_path.with_suffix(suffix), + ignore_index=self._ioconfig.ignore_index, verbose=self.verbose, ) save_paths = out_file diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 22b6f3e4b..bac6a0682 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1313,6 +1313,7 @@ def process_contours( contours: list[np.ndarray], hierarchy: np.ndarray, scale_factor: tuple[float, float] = (1, 1), + offset: np.ndarray | None = None, properties: dict[str, JSON] | None = None, ) -> list[Annotation]: """Process contours and hierarchy to create annotations. @@ -1324,6 +1325,8 @@ def process_contours( A list of hierarchy. scale_factor (tuple[float, float]): The scale factor to use when loading the annotations. + offset (np.ndarray | None): + Optional offset to be added to the coordinates of the annotations. properties (dict | None): Optional properties to include with each annotation type. 
@@ -1342,6 +1345,8 @@ for i, layer_ in enumerate(contours): coords: np.ndarray = layer_.squeeze() scaled_coords: np.ndarray = np.array([np.array(scale_factor) * coords]) + if offset is not None: + scaled_coords += offset # save one points as a line, otherwise save the Polygon if len(layer_) > 2: # noqa: PLR2004 @@ -1420,7 +1425,9 @@ def dict_to_store_semantic_segmentor( output_type: str, class_dict: dict | None = None, save_path: Path | None = None, + offset: np.ndarray | None = None, *, + ignore_index: int | None = None, verbose: bool = True, ) -> AnnotationStore | dict | Path: """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore. @@ -1441,6 +1448,12 @@ save_path (str or Path): Optional Output directory to save the Annotation Store results. + offset (np.ndarray | None): + Optional offset to be added to the coordinates of the annotations. + ignore_index (int | None): + Any index to ignore in the layer list. e.g., background. + Defaults to None, in which case all the layers are saved to + the annotationstore or JSON file. verbose (bool): Whether to display logs and progress bar.
@@ -1453,8 +1466,10 @@ def dict_to_store_semantic_segmentor( """ preds = da.from_array(patch_output["predictions"], chunks="auto") + ignore_index = -1 if ignore_index is None else ignore_index # Get the number of unique predictions layer_list = da.unique(preds).compute() + layer_list = np.delete(layer_list, np.where(layer_list == ignore_index)) if class_dict is None: class_dict = {int(i): int(i) for i in layer_list.tolist()} @@ -1475,6 +1490,7 @@ def dict_to_store_semantic_segmentor( scale_factor=scale_factor, class_dict=class_dict, save_path=save_path, + offset=offset, verbose=verbose, ) @@ -1566,6 +1582,7 @@ def _semantic_segmentations_as_annotations( scale_factor: tuple[float, float], class_dict: dict, save_path: Path | None = None, + offset: np.ndarray | None = None, *, verbose: bool = True, ) -> AnnotationStore | Path: @@ -1593,7 +1610,11 @@ def _semantic_segmentations_as_annotations( contours = cast("list[np.ndarray]", contours) annotations_list_ = process_contours( - contours, hierarchy, scale_factor, {"type": class_label, "class": class_id} + contours=contours, + hierarchy=hierarchy, + scale_factor=scale_factor, + offset=offset, + properties={"type": class_label, "class": class_id}, ) annotations_list.extend(annotations_list_) diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py index 29aa0c0ad..4087b2178 100644 --- a/tiatoolbox/visualization/bokeh_app/main.py +++ b/tiatoolbox/visualization/bokeh_app/main.py @@ -29,6 +29,7 @@ ColorPicker, Column, ColumnDataSource, + CustomAction, CustomJS, CustomJSTickFormatter, DataTable, @@ -68,8 +69,9 @@ # GitHub actions seems unable to find TIAToolbox unless this is here sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) from tiatoolbox import logger -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - NucleusInstanceSegmentor, +from tiatoolbox.models.engine.nucleus_instance_segmentor import NucleusInstanceSegmentor +from 
tiatoolbox.models.engine.prompt_segmentor import ( # skipcq: FLK-E402 + PromptSegmentor, ) from tiatoolbox.tools.pyramid import ZoomifyGenerator from tiatoolbox.utils.misc import select_device @@ -1096,6 +1098,23 @@ def slide_select_cb(attr: str, old: str, new: str) -> None: # noqa: ARG001 layer_drop_cb(dummy_attr) +def clear_overlay_cb(attr: str) -> None: # noqa: ARG001 + """Clear all overlays and reset to just the slide.""" + UI["pt_source"].data = {"x": [], "y": []} + UI["box_source"].data = {"x": [], "y": [], "width": [], "height": []} + UI["node_source"].data = {"x_": [], "y_": [], "node_color_": []} + UI["edge_source"].data = {"x0_": [], "y0_": [], "x1_": [], "y1_": []} + UI["hover"].tooltips = None + if len(UI["p"].renderers) > N_PERMANENT_RENDERERS: + for r in UI["p"].renderers[N_PERMANENT_RENDERERS:].copy(): + UI["p"].renderers.remove(r) + UI["vstate"].layer_dict = {"slide": 0, "rect": 1, "pts": 2, "nodes": 3, "edges": 4} + UI["color_column"].children = [] + UI["type_column"].children = [] + UI["s"].put(f"http://{host2}:{port}/tileserver/clear_overlays") + change_tiles("slide") + + def handle_graph_layer(attr: MenuItemClick) -> None: # skipcq: PY-R1000 """Handle adding a graph layer.""" do_feats = False @@ -1322,6 +1341,8 @@ def to_model_cb(attr: ButtonClick) -> None: # noqa: ARG001 """Callback to run currently selected model.""" if UI["vstate"].current_model == "hovernet": segment_on_box() + elif UI["vstate"].current_model == "SAM": + sam_segment() # Add any other models here else: # pragma: no cover logger.warning("unknown model") @@ -1479,6 +1500,103 @@ def segment_on_box() -> None: rmtree(tmp_mask_dir) +def sam_segment() -> None: + """Callback to run SAM using a point on the slide. + + Will run PromptSegmentor on selected region of wsi defined + by the point in pt_source. 
+ + """ + prompt_segmentor = PromptSegmentor() + x_start = max(0, UI["p"].x_range.start) + y_start = max(0, -UI["p"].y_range.end) + x_end = min(UI["p"].x_range.end, UI["vstate"].dims[0]) + y_end = min(-UI["p"].y_range.start, UI["vstate"].dims[1]) + offset = np.array([x_start, y_start]) + prompt_segmentor.offset = offset + + height = y_end - y_start + width = x_end - x_start + res, scale_factor = prompt_segmentor.calc_mpp( + (width, height), UI["vstate"].mpp[0], 1500 + ) + + # Get point coordinates + x = np.round(UI["pt_source"].data["x"]) + y = np.round(UI["pt_source"].data["y"]) + point_coords = ( + ( + np.array([[[x[i], -y[i]] for i in range(len(x))]], np.uint32) + - np.array([[x_start, y_start]]) + ) + / scale_factor + if len(x) > 0 + else None + ) + + # Get box coordinates + x = np.round(UI["box_source"].data["x"]) + y = np.round(UI["box_source"].data["y"]) + + x = [ + round(UI["box_source"].data["x"][i] - 0.5 * UI["box_source"].data["width"][i]) + for i in range(len(x)) + ] + y = [ + -round(UI["box_source"].data["y"][i] + 0.5 * UI["box_source"].data["height"][i]) + for i in range(len(y)) + ] + width = [round(UI["box_source"].data["width"][i]) for i in range(len(x))] + height = [round(UI["box_source"].data["height"][i]) for i in range(len(x))] + box_coords = ( + ( + np.array( + [ + [ + [x[i], y[i], x[i] + width[i], height[i] + y[i]] + for i in range(len(x)) + ] + ], + np.uint32, + ) + - np.array( + [[x_start, y_start, x_start, y_start]], + ) + ) + / scale_factor + if len(x) > 0 + else None + ) + + tmp_save_dir = Path(tempfile.mkdtemp(suffix="bokeh_temp")) + + # read the region of interest from the slide + roi = UI["vstate"].wsi.read_bounds( + (int(x_start), int(y_start), int(x_end), int(y_end)), + resolution=res, + units="mpp", + ) + + # Run SAM on the point + prediction = prompt_segmentor.run( + images=[roi], + device=select_device(on_gpu=torch.cuda.is_available()), + save_dir=tmp_save_dir, + point_coords=point_coords, + box_coords=box_coords, + ) + + ann_loc 
= str(prediction[0]) + + fname = make_safe_name(ann_loc) + resp = UI["s"].put( + f"http://{host2}:{port}/tileserver/overlay", + data={"overlay_path": fname}, + ) + ann_types = json.loads(resp.text) + update_ui_on_new_annotations(ann_types) + + # endregion # Set up main window @@ -1695,7 +1813,7 @@ def gather_ui_elements( # noqa: PLR0915 button_type="success", width=80, max_width=90, - height=35, + height=45, sizing_mode="stretch_width", name=f"to_model{win_num}", ) @@ -1707,7 +1825,7 @@ def gather_ui_elements( # noqa: PLR0915 ) model_drop = Select( title="choose model:", - options=["hovernet"], + options=["hovernet", "SAM"], height=25, width=120, max_width=120, @@ -1720,10 +1838,18 @@ def gather_ui_elements( # noqa: PLR0915 button_type="success", max_width=90, width=80, - height=35, + height=45, sizing_mode="stretch_width", name=f"save_button{win_num}", ) + clear_button = Button( + label="Clear Overlays", + button_type="warning", + width=120, + height=40, + sizing_mode="stretch_width", + name=f"clear_button{win_num}", + ) type_cprop_tt = Tooltip( content=HTML( """Select a type of object, and a property to color by. 
Objects of @@ -1791,6 +1917,7 @@ def gather_ui_elements( # noqa: PLR0915 filter_input.on_change("value", filter_input_cb) cprop_input.on_change("value", cprop_input_cb) type_cmap_select.on_change("value", type_cmap_cb) + clear_button.on_click(clear_overlay_cb) # Create some layouts type_column = column(children=layer_boxes, name=f"type_column{win_num}") @@ -1828,6 +1955,7 @@ def gather_ui_elements( # noqa: PLR0915 "cmap_row", "type_cmap_select", "model_row", + "clear_button", "type_select_row", ], [ @@ -1840,18 +1968,19 @@ def gather_ui_elements( # noqa: PLR0915 cmap_row, type_cmap_select, model_row, + clear_button, type_select_row, ], strict=False, ), ) if "ui_elements_1" in doc_config: - # Only add the elements specified in config file + # Dont add elements specified 0 in config file ui_layout = column( [ ui_elements_1[el] - for el in doc_config["ui_elements_1"] - if doc_config["ui_elements_1"][el] == 1 + for el in ui_elements_1 + if doc_config["ui_elements_1"].get(el, 1) == 1 ], sizing_mode="stretch_width", ) @@ -1887,7 +2016,7 @@ def gather_ui_elements( # noqa: PLR0915 [ ui_elements_2[el] for el in doc_config["ui_elements_2"] - if doc_config["ui_elements_2"][el] == 1 + if doc_config["ui_elements_2"].get(el, 1) == 1 ], ) else: @@ -2013,6 +2142,22 @@ def make_window(vstate: ViewerState) -> dict: # noqa: PLR0915 p.add_tools(BoxEditTool(renderers=[r], num_objects=1)) p.add_tools(PointDrawTool(renderers=[c])) p.add_tools(TapTool()) + clear_code = """ + box_source.clear() + pt_source.clear() + """ + p.add_tools( + CustomAction( + callback=CustomJS( + args={ + "box_source": box_source, + "pt_source": pt_source, + }, + code=clear_code, + ), + description="Clear", + ), + ) if get_from_config(["opts", "hover_on"], 0) == 0: p.toolbar.active_inspect = None diff --git a/tiatoolbox/visualization/tileserver.py b/tiatoolbox/visualization/tileserver.py index d726cb79c..1205776b3 100644 --- a/tiatoolbox/visualization/tileserver.py +++ b/tiatoolbox/visualization/tileserver.py @@ 
-144,6 +144,7 @@ def __init__( # noqa: PLR0915 self.route("/tileserver/session_id")(self.session_id) self.route("/tileserver/color_prop", methods=["PUT"])(self.change_prop) self.route("/tileserver/slide", methods=["PUT"])(self.change_slide) + self.route("/tileserver/clear_overlays", methods=["PUT"])(self.clear_overlays) self.route("/tileserver/cmap", methods=["PUT"])(self.change_mapper) self.route( "/tileserver/annotations", @@ -421,6 +422,16 @@ def change_slide(self: TileServer) -> str: return "done" + def clear_overlays(self: TileServer) -> str: + """Clear all overlays.""" + session_id = self._get_session_id() + slide_layer = self.layers[session_id]["slide"] + self.layers[session_id] = {"slide": slide_layer} + self.pyramids[session_id] = { + "slide": ZoomifyGenerator(slide_layer, tile_size=256), + } + return "done" + def change_mapper(self: TileServer) -> str: """Change the colour mapper for the overlay.""" session_id = self._get_session_id() @@ -714,6 +725,7 @@ def commit_db(self: TileServer) -> str: if ( layer.store.path.suffix == ".db" and layer.store.path.name != f"temp_{session_id}.db" + and not str(layer.store.path.parent.name).endswith("bokeh_temp") ): logger.info("%s*.db committed.", layer.store.path.stem) layer.store.commit() @@ -815,9 +827,8 @@ def get_channels(self: TileServer) -> Response: session_id = self._get_session_id() if isinstance(self.layers[session_id]["slide"].post_proc, MultichannelToRGB): if not self.layers[session_id]["slide"].post_proc.is_validated: - _ = self.layers[session_id]["slide"].slide_thumbnail( - resolution=8.0, units="mpp" - ) + # trigger validation of channels with small read + _ = self.layers[session_id]["slide"].read_rect((0, 0), (100, 100)) return jsonify( { "channels": self.layers[session_id]["slide"].post_proc.color_dict,