diff --git a/docs/source/api/morphoclass.console.cmd_extract_features_and_train.rst b/docs/source/api/morphoclass.console.cmd_extract_features_and_train.rst new file mode 100644 index 0000000..f41e304 --- /dev/null +++ b/docs/source/api/morphoclass.console.cmd_extract_features_and_train.rst @@ -0,0 +1,7 @@ +morphoclass.console.cmd\_extract\_features\_and\_train module +============================================================= + +.. automodule:: morphoclass.console.cmd_extract_features_and_train + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/morphoclass.console.rst b/docs/source/api/morphoclass.console.rst index 67415de..ab3970c 100644 --- a/docs/source/api/morphoclass.console.rst +++ b/docs/source/api/morphoclass.console.rst @@ -9,6 +9,7 @@ Submodules morphoclass.console.cmd_evaluate morphoclass.console.cmd_extract_features + morphoclass.console.cmd_extract_features_and_train morphoclass.console.cmd_morphometrics morphoclass.console.cmd_organise_dataset morphoclass.console.cmd_performance_table diff --git a/setup.cfg b/setup.cfg index 5bb4436..20f60a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ python_requires = >=3.8,<3.9 install_requires = PyYAML captum + cleanlab==1.0.1 click dash dash-bootstrap-components diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index c952be5..93eda5e 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -113,6 +113,29 @@ def cli( no_simplify_graph: bool, keep_diagram: bool, force: bool, +) -> None: + """Extract morphology features.""" + return extract_features( + csv_path, + neurite_type, + feature, + output_dir, + orient, + no_simplify_graph, + keep_diagram, + force, + ) + + +def extract_features( + csv_path: StrPath, + neurite_type: str, + feature: str, + output_dir: StrPath, + orient: bool, + no_simplify_graph: bool, + keep_diagram: bool, + force: bool, ) -> None: """Extract morphology features.""" output_dir = pathlib.Path(output_dir) diff --git a/src/morphoclass/console/cmd_extract_features_and_train.py b/src/morphoclass/console/cmd_extract_features_and_train.py new file mode 100644 index 0000000..dfa4338 --- /dev/null +++ b/src/morphoclass/console/cmd_extract_features_and_train.py @@ -0,0 +1,143 @@ +# Copyright © 2022-2022 Blue Brain Project/EPFL +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the `morphoclass train` CLI command.""" +from __future__ import annotations + +import logging +import pathlib +from typing import Literal + +import click + +from morphoclass.types import StrPath + +logger = logging.getLogger(__name__) + + +@click.command(name="train", help="Train a morphology classification model.") +@click.argument("csv_path", type=click.Path(dir_okay=False)) +@click.argument("neurite_type", type=click.Choice(["apical", "axon", "basal", "all"])) +@click.argument( + "feature", + type=click.Choice( + [ + "graph-rd", + "graph-proj", + "diagram-tmd-rd", + "diagram-tmd-proj", + "diagram-deepwalk", + "image-tmd-rd", + "image-tmd-proj", + "image-deepwalk", + ] + ), +) +@click.option( + "--orient", + is_flag=True, + help="Orient the neurons so that the apicals are aligned with the positive y-axis.", +) +@click.option( + "--no-simplify-graph", + is_flag=True, + help=""" + By default the neurite graph is reduced to branching nodes only. With this + flag the full neurite graph will be preserved. + """, +) +@click.option( + "--keep-diagram", + is_flag=True, + help="After converting the diagram to persistence image don't discard the diagram.", +) +@click.option( + "--model-config", + type=click.Path(exists=True, dir_okay=False), + required=True, + help=""" + The model configuration file. + For inspiration, model configuration files can be found under + dvc/training/configs/ + """, +) +@click.option( + "--splitter-config", + type=click.Path(exists=True, dir_okay=False), + required=True, + help=""" + The splitter configuration file. + For inspiration, splitter configuration files can be found under + dvc/training/configs/ + """, +) +@click.option( + "--output-dir", + type=click.Path(file_okay=False), + required=True, + help="The output directory.", +) +@click.option( + "-f", + "--force", + type=click.BOOL, + default=False, + is_flag=True, + help="Don't ask for overwriting existing output files.", +) +def cli( + csv_path: StrPath, + neurite_type: Literal["apical", "axon", "basal", "all"], + feature: Literal[ + "graph-rd", + "graph-proj", + "diagram-tmd-rd", + "diagram-tmd-proj", + "diagram-deepwalk", + "image-tmd-rd", + "image-tmd-proj", + "image-deepwalk", + ], + orient: bool, + no_simplify_graph: bool, + keep_diagram: bool, + model_config: StrPath, + splitter_config: StrPath, + output_dir: StrPath, + force: bool, +) -> None: + """Extract features and train the model.""" + from morphoclass.console.cmd_extract_features import extract_features + from morphoclass.console.cmd_train import train + + input_csv = pathlib.Path(csv_path).resolve() + output_dir = pathlib.Path(output_dir).resolve() + + extract_features( + input_csv, + neurite_type, + feature, + output_dir / "features", + orient, + no_simplify_graph, + keep_diagram, + force, + ) + + train( + output_dir / "features", + model_config, + splitter_config, + output_dir / "checkpoints", + force, + ) diff --git a/src/morphoclass/console/cmd_train.py b/src/morphoclass/console/cmd_train.py index 72d9ee0..56edb3d 100644 --- a/src/morphoclass/console/cmd_train.py +++ b/src/morphoclass/console/cmd_train.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the `morphoclass train` CLI command.""" +"""Implementation of the `morphoclass train-after-extraction` CLI command.""" from __future__ import annotations import logging @@ -25,7 +25,12 @@ logger = logging.getLogger(__name__) -@click.command(name="train", help="Train a morphology classification model.") +@click.command( + name="train-after-extraction", + help=""" + Train a morphology classification model. + Features need to be first extracted.""", +) @click.option( "--features-dir", type=click.Path(exists=True, file_okay=False), @@ -64,6 +69,23 @@ def cli( splitter_config: StrPath, checkpoint_dir: StrPath, force: bool, +) -> None: + """Training and evaluation of the model.""" + return train( + features_dir, + model_config, + splitter_config, + checkpoint_dir, + force, + ) + + +def train( + features_dir: StrPath, + model_config: StrPath, + splitter_config: StrPath, + checkpoint_dir: StrPath, + force: bool, ) -> None: """Training and evaluation of the model. diff --git a/src/morphoclass/console/main.py b/src/morphoclass/console/main.py index 886d2f8..3566c7a 100644 --- a/src/morphoclass/console/main.py +++ b/src/morphoclass/console/main.py @@ -22,6 +22,7 @@ import morphoclass from morphoclass.console import cmd_evaluate from morphoclass.console import cmd_extract_features +from morphoclass.console import cmd_extract_features_and_train from morphoclass.console import cmd_morphometrics from morphoclass.console import cmd_organise_dataset from morphoclass.console import cmd_performance_table @@ -143,3 +144,4 @@ def cli(verbose: int, log_file_path: pathlib.Path | None) -> None: cli.add_command(cmd_performance_table.cli) cli.add_command(cmd_extract_features.cli) cli.add_command(cmd_morphometrics.cli) +cli.add_command(cmd_extract_features_and_train.cli)