diff --git a/.github/workflows/reusable-test.yaml b/.github/workflows/reusable-test.yaml
index 502a0d35a..80b2a4ff4 100644
--- a/.github/workflows/reusable-test.yaml
+++ b/.github/workflows/reusable-test.yaml
@@ -35,6 +35,7 @@ jobs:
         uses: astral-sh/setup-uv@v6
         with:
           version: "0.10.0"
+          enable-cache: "true"
 
       - name: Install dependencies for Python ${{ matrix.python-version }}
         run: |
diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json
index d19eda75b..e45d8cbbf 100644
--- a/docs/optimizer_config.schema.json
+++ b/docs/optimizer_config.schema.json
@@ -50,6 +50,19 @@
           "title": "Trust Remote Code",
           "type": "boolean"
         },
+        "revision": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "description": "Git revision (branch, tag, or commit hash) to load from the Hugging Face Hub.",
+          "title": "Revision"
+        },
         "train_head": {
           "default": false,
           "description": "Whether to train the head of the model. If False, LogReg will be trained.",
@@ -262,6 +275,19 @@
           "description": "Whether to trust the remote code when loading the model.",
           "title": "Trust Remote Code",
           "type": "boolean"
+        },
+        "revision": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "description": "Git revision (branch, tag, or commit hash) to load from the Hugging Face Hub.",
+          "title": "Revision"
         }
       },
       "title": "HFModelConfig",
@@ -515,6 +541,7 @@
             "truncation": true
           },
           "trust_remote_code": false,
+          "revision": null,
           "train_head": false
         }
       },
@@ -531,7 +558,8 @@
             "padding": true,
             "truncation": true
           },
-          "trust_remote_code": false
+          "trust_remote_code": false,
+          "revision": "refs/pr/16"
         }
       },
       "hpo_config": {
diff --git a/src/autointent/_optimization_config.py b/src/autointent/_optimization_config.py
index b78189b29..25b591be1 100644
--- a/src/autointent/_optimization_config.py
+++ b/src/autointent/_optimization_config.py
@@ -11,6 +11,7 @@
     HFModelConfig,
     HPOConfig,
     LoggingConfig,
+    get_default_hfmodel_config,
     initialize_embedder_config,
 )
 
@@ -40,7 +41,7 @@ def validate_embedder_config(cls, v: Any) -> EmbedderConfig:  # noqa: ANN401
 
     cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()
 
-    transformer_config: HFModelConfig = HFModelConfig()
+    transformer_config: HFModelConfig = get_default_hfmodel_config()
 
     hpo_config: HPOConfig = HPOConfig()
diff --git a/src/autointent/_pipeline/_pipeline.py b/src/autointent/_pipeline/_pipeline.py
index 3d62f07d1..a764f1d78 100644
--- a/src/autointent/_pipeline/_pipeline.py
+++ b/src/autointent/_pipeline/_pipeline.py
@@ -22,6 +22,7 @@
     LoggingConfig,
     VectorIndexConfig,
     get_default_embedder_config,
+    get_default_hfmodel_config,
     get_default_vector_index_config,
 )
 from autointent.custom_types import NodeType
@@ -64,7 +65,7 @@ def __init__(
             self.embedder_config = get_default_embedder_config()
             self.cross_encoder_config = CrossEncoderConfig()
             self.data_config = DataConfig()
-            self.transformer_config = HFModelConfig()
+            self.transformer_config = get_default_hfmodel_config()
             self.hpo_config = HPOConfig()
             self.vector_index_config = get_default_vector_index_config()
         elif not isinstance(nodes[0], InferenceNode):
diff --git a/src/autointent/_wrappers/embedder/sentence_transformers.py b/src/autointent/_wrappers/embedder/sentence_transformers.py
index 0948ea406..e19833e52 100644
--- a/src/autointent/_wrappers/embedder/sentence_transformers.py
+++ b/src/autointent/_wrappers/embedder/sentence_transformers.py
@@ -84,6 +84,7 @@ def _load_model(self) -> SentenceTransformer:
             prompts=self.config.get_prompt_config(),
             similarity_fn_name=self.config.similarity_fn_name,
             trust_remote_code=self.config.trust_remote_code,
+            revision=self.config.revision,
         )
         self._model = res
         return self._model
diff --git a/src/autointent/configs/__init__.py b/src/autointent/configs/__init__.py
index b0495c372..01f7f3106 100644
--- a/src/autointent/configs/__init__.py
+++ b/src/autointent/configs/__init__.py
@@ -17,6 +17,7 @@
     EmbedderFineTuningConfig,
     HFModelConfig,
     TokenizerConfig,
+    get_default_hfmodel_config,
 )
 from ._vector_index import FaissConfig, OpenSearchConfig, VectorIndexConfig, get_default_vector_index_config
 
@@ -40,6 +41,7 @@
     "VectorIndexConfig",
     "VocabConfig",
     "get_default_embedder_config",
+    "get_default_hfmodel_config",
     "get_default_vector_index_config",
     "initialize_embedder_config",
 ]
diff --git a/src/autointent/configs/_transformers.py b/src/autointent/configs/_transformers.py
index 9bdc9597d..90bed58e8 100644
--- a/src/autointent/configs/_transformers.py
+++ b/src/autointent/configs/_transformers.py
@@ -53,6 +53,7 @@ class HFModelConfig(BaseModel):
     fp16: bool = Field(False, description="Whether to use mixed precision training (not all devices support this).")
     tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
     trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.")
+    revision: str | None = Field(None, description="Git revision (branch, tag, or commit hash) to load from the Hugging Face Hub.")
 
     @classmethod
     def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
@@ -75,6 +76,11 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
         return cls(**values)
 
 
+def get_default_hfmodel_config() -> HFModelConfig:
+    """Default model config used across the library: a tiny BERT pinned to a fixed Hub revision."""
+    return HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
+
+
 class CrossEncoderConfig(HFModelConfig):
     model_name: str = Field("cross-encoder/ms-marco-MiniLM-L6-v2", description="Name of the hugging face model.")
     train_head: bool = Field(
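
For orientation (an illustration, not part of the patch): a minimal sketch of how the
new `revision` field and default factory above are meant to be used, based only on the
definitions in this diff.

    from autointent.configs import HFModelConfig, get_default_hfmodel_config

    # Pin a model to a fixed revision (branch, tag, or commit hash) so repeated
    # runs keep resolving the same weights even if the Hub repo moves on.
    config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")

    # The new default factory bakes in exactly that pin; a bare HFModelConfig()
    # leaves revision=None, i.e. whatever the repo's default branch points at.
    assert get_default_hfmodel_config().revision == "refs/pr/16"
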
diff --git a/src/autointent/context/_context.py b/src/autointent/context/_context.py
index 7be00ac22..7410ec4b3 100644
--- a/src/autointent/context/_context.py
+++ b/src/autointent/context/_context.py
@@ -16,6 +16,7 @@
     HPOConfig,
     LoggingConfig,
     VectorIndexConfig,
+    get_default_hfmodel_config,
 )
 
 from .data_handler import DataHandler
@@ -25,9 +26,7 @@
     from pathlib import Path
 
     from autointent import Dataset
-    from autointent.configs import (
-        DataConfig,
-    )
+    from autointent.configs import DataConfig
 
 
 class Context:
@@ -202,4 +201,4 @@ def resolve_transformer(self) -> HFModelConfig:
         """
         if hasattr(self, "transformer_config"):
             return self.transformer_config
-        return HFModelConfig()
+        return get_default_hfmodel_config()
diff --git a/src/autointent/context/data_handler/__init__.py b/src/autointent/context/data_handler/__init__.py
index a2e91358f..24a3d3b55 100644
--- a/src/autointent/context/data_handler/__init__.py
+++ b/src/autointent/context/data_handler/__init__.py
@@ -1,4 +1,15 @@
 from ._data_handler import DataHandler
-from ._stratification import StratifiedSplitter, split_dataset
+from ._stratification import (
+    SplitReadinessResult,
+    StratifiedSplitter,
+    check_split_readiness,
+    split_dataset,
+)
 
-__all__ = ["DataHandler", "StratifiedSplitter", "split_dataset"]
+__all__ = [
+    "DataHandler",
+    "SplitReadinessResult",
+    "StratifiedSplitter",
+    "check_split_readiness",
+    "split_dataset",
+]
diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py
index 1482383ab..3ac3a3b59 100644
--- a/src/autointent/context/data_handler/_stratification.py
+++ b/src/autointent/context/data_handler/_stratification.py
@@ -7,6 +7,8 @@ from __future__ import annotations
 
 import logging
+from collections import Counter
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -16,7 +18,7 @@ from transformers import set_seed
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence
 
     from datasets import Dataset as HFDataset
     from numpy import typing as npt
 
@@ -27,6 +29,37 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass(frozen=True)
+class StratifyInputs:
+    """Inputs for stratified splitting: the effective dataset and post-split hook.
+
+    Used internally so that the same OOS/mapping logic feeds both splitting and
+    readiness checks.
+    """
+
+    dataset: HFDataset
+    multilabel: bool
+    test_size: float
+    post_split_fn: Callable[[HFDataset, HFDataset], tuple[HFDataset, HFDataset]]
+
+
+@dataclass(frozen=True)
+class SplitReadinessResult:
+    """Result of checking whether a dataset can be split with stratification.
+
+    Attributes:
+        ready: True if stratification can be performed (enough samples per class).
+        underpopulated_classes: List of (label, count) for classes below the minimum.
+        min_samples_per_class_required: Minimum samples per class used for the check.
+        reason: Human-readable reason when not ready (missing split or underpopulated classes).
+    """
+
+    ready: bool
+    underpopulated_classes: list[tuple[int | str | None, int]]
+    min_samples_per_class_required: int
+    reason: str | None
+
+
 class StratifiedSplitter:
     """A class for stratified splitting of datasets.
 
@@ -77,25 +110,9 @@ def __call__(
         Raises:
             ValueError: If OOS samples are present but allow_oos_in_train is not specified.
         """
-        if not self._has_oos_samples(dataset):
-            train, test = self._split_without_oos(dataset, multilabel, self.test_size)
-            if self.is_few_shot:
-                train, test = create_few_shot_split(
-                    train,
-                    test,
-                    multilabel=multilabel,
-                    label_column=self.label_feature,
-                    examples_per_label=self.examples_per_label,
-                )
-            return train, test
-        if allow_oos_in_train is None:
-            msg = (
-                "Error while splitting dataset. It contains OOS samples, "
-                "you need to set the parameter allow_oos_in_train."
-            )
-            raise ValueError(msg)
-        splitter = self._split_allow_oos_in_train if allow_oos_in_train else self._split_disallow_oos_in_train
-        train, test = splitter(dataset, multilabel)
+        inputs = self.get_stratify_inputs(dataset, multilabel, allow_oos_in_train)
+        train, test = self._split_without_oos(inputs.dataset, inputs.multilabel, inputs.test_size)
+        train, test = inputs.post_split_fn(train, test)
         if self.is_few_shot:
             train, test = create_few_shot_split(
                 train,
@@ -106,7 +123,7 @@
             )
         return train, test
 
-    def _has_oos_samples(self, dataset: HFDataset) -> bool:
+    def has_oos_samples(self, dataset: HFDataset) -> bool:
         """Check if the dataset contains out-of-scope samples.
 
         Args:
@@ -118,6 +135,106 @@
         oos_samples = dataset.filter(lambda sample: sample[self.label_feature] is None)
         return len(oos_samples) > 0
 
+    def get_stratify_inputs(
+        self, dataset: HFDataset, multilabel: bool, allow_oos_in_train: bool | None
+    ) -> StratifyInputs:
+        """Return the effective dataset and post-split hook for stratification.
+
+        Single source of truth for OOS handling: both splitting and readiness
+        checks use this, so the logic is not duplicated.
+
+        Args:
+            dataset: The input dataset (may contain OOS).
+            multilabel: Whether the dataset is multi-label.
+            allow_oos_in_train: Whether OOS samples are allowed in the train split.
+                Must be set when the dataset contains OOS samples.
+
+        Returns:
+            StratifyInputs with the dataset to stratify on and a post_split_fn.
+
+        Raises:
+            ValueError: If the dataset contains OOS samples and allow_oos_in_train is None.
+        """
+        has_oos = self.has_oos_samples(dataset)
+        if has_oos and allow_oos_in_train is None:
+            msg = (
+                "Error while splitting dataset. It contains OOS samples, "
+                "you need to set the parameter allow_oos_in_train."
+            )
+            raise ValueError(msg)
+        if not has_oos:
+            return StratifyInputs(
+                dataset=dataset,
+                multilabel=multilabel,
+                test_size=self.test_size,
+                post_split_fn=lambda train_ds, test_ds: (train_ds, test_ds),
+            )
+        if allow_oos_in_train:
+            return self._stratify_inputs_allow_oos(dataset, multilabel)
+        return self._stratify_inputs_disallow_oos(dataset, multilabel)
+
+    def _stratify_inputs_allow_oos(self, dataset: HFDataset, multilabel: bool) -> StratifyInputs:
+        """Return stratify inputs when OOS samples are allowed in the train split.
+
+        OOS is mapped to a class so it is stratified; post_split_fn unmaps it.
+        """
+        if multilabel:
+            in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None)
+            n_classes = len(in_domain_sample[self.label_feature])
+            mapped_dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes})
+
+            def unmap_oos_multilabel(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]:
+                return (
+                    train_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}),
+                    test_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}),
+                )
+
+            return StratifyInputs(
+                dataset=mapped_dataset,
+                multilabel=False,  # stratify whole label combinations as multiclass after OOS mapping
+                test_size=self.test_size,
+                post_split_fn=unmap_oos_multilabel,
+            )
+        oos_class_id = len(dataset.unique(self.label_feature)) - 1
+        mapped_dataset = dataset.map(self._map_label, fn_kwargs={"old": None, "new": oos_class_id})
+
+        def unmap_oos_multiclass(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]:
+            return (
+                train_ds.map(
+                    self._map_label,
+                    fn_kwargs={"old": oos_class_id, "new": None},
+                ),
+                test_ds.map(
+                    self._map_label,
+                    fn_kwargs={"old": oos_class_id, "new": None},
+                ),
+            )
+
+        return StratifyInputs(
+            dataset=mapped_dataset,
+            multilabel=False,
+            test_size=self.test_size,
+            post_split_fn=unmap_oos_multiclass,
+        )
+
+    def _stratify_inputs_disallow_oos(self, dataset: HFDataset, multilabel: bool) -> StratifyInputs:
+        """Return stratify inputs when OOS samples are not allowed in the train split.
+
+        Only in-domain data is stratified; post_split_fn appends OOS to the test set.
+        """
+        in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset)
+        adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset))
+
+        def concat_oos_to_test(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]:
+            return (train_ds, concatenate_datasets([test_ds, out_of_domain_dataset]))
+
+        return StratifyInputs(
+            dataset=in_domain_dataset,
+            multilabel=multilabel,
+            test_size=adjusted_test_size,
+            post_split_fn=concat_oos_to_test,
+        )
+
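
As a sanity check on the new contract (a sketch with toy data, not part of the patch;
the `StratifiedSplitter` construction mirrors how `check_split_readiness` builds it
further down in this diff, and the `select` calls stand in for a real stratified split):

    from datasets import Dataset as HFDataset

    from autointent.context.data_handler import StratifiedSplitter

    # Hypothetical toy data: two in-domain classes plus one OOS row (label=None).
    data = HFDataset.from_dict(
        {"utterance": ["a", "b", "c", "d", "e"], "label": [0, 0, 1, 1, None]}
    )
    splitter = StratifiedSplitter(test_size=0.5, label_feature="label", random_seed=0)

    inputs = splitter.get_stratify_inputs(data, multilabel=False, allow_oos_in_train=False)
    assert len(inputs.dataset) == 4  # the OOS row is held out of stratification

    # The hook is what re-attaches the OOS row to whatever test split is produced.
    train, test = inputs.post_split_fn(inputs.dataset.select([0, 1]), inputs.dataset.select([2, 3]))
    assert len(test) == 3  # two in-domain rows plus the OOS row
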
     def _split_without_oos(self, dataset: HFDataset, multilabel: bool, test_size: float) -> tuple[HFDataset, HFDataset]:
         """Split dataset that doesn't contain OOS samples.
 
@@ -170,42 +287,6 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np
         )
         return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature])))
 
-    def _split_allow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]:
-        """Proportionally distribute OOS samples between two splits.
-
-        Internally creates a dataset copy with some integer assigned as OOS class id.
-        With OOS samples treated as a separate class we obtain proportional distribution
-        of them between two splits.
-
-        Args:
-            dataset: Dataset to split.
-            multilabel: Whether the dataset is multi-label.
-
-        Returns:
-            A tuple containing training and testing datasets.
-        """
-        # add oos as a class
-        if multilabel:
-            in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None)
-            n_classes = len(in_domain_sample[self.label_feature])
-            dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes})
-        else:
-            oos_class_id = len(dataset.unique(self.label_feature)) - 1
-            dataset = dataset.map(self._map_label, fn_kwargs={"old": None, "new": oos_class_id})
-
-        # perform stratified splitting
-        train, test = self._split_without_oos(dataset, multilabel=False, test_size=self.test_size)
-
-        # remove oos as a class
-        if multilabel:
-            train = train.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes})
-            test = test.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes})
-        else:
-            train = train.map(self._map_label, fn_kwargs={"old": oos_class_id, "new": None})
-            test = test.map(self._map_label, fn_kwargs={"old": oos_class_id, "new": None})
-
-        return train, test
-
     def _map_label(
         self, sample: dict[str, str | LabelType], old: LabelType, new: LabelType
     ) -> dict[str, str | LabelType]:
 
@@ -253,25 +334,6 @@ def _remove_oos_label(self, sample: dict[str, str | LabelType], n_classes: int)
         sample[self.label_feature] = None  # type: ignore[assignment]
         return sample
 
-    def _split_disallow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]:
-        """Move all OOS samples to test split.
-
-        This method preserves the defined test_size proportion so you won't get unexpectedly
-        large test set even you have lots of OOS samples.
-
-        Args:
-            dataset: Dataset to split.
-            multilabel: Whether the dataset is multi-label.
-
-        Returns:
-            A tuple containing training and testing datasets.
-        """
-        in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset)
-        adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset))
-        train, test = self._split_without_oos(in_domain_dataset, multilabel, adjusted_test_size)
-        test = concatenate_datasets([test, out_of_domain_dataset])
-        return train, test
-
     def _separate_oos(self, dataset: HFDataset) -> tuple[HFDataset, HFDataset]:
         """Separate OOS samples from in-domain samples.
 
@@ -315,6 +377,82 @@ def _get_adjusted_test_size(self, n: int, k: int) -> float:
         return res
 
 
+def _check_multiclass_counts(
+    dataset: HFDataset, label_feature: str, min_samples_per_class: int
+) -> list[tuple[int | str | None, int]]:
+    """Return (label, count) for each class with fewer than min_samples_per_class samples."""
+    labels = dataset[label_feature]
+    counts = Counter(labels)
+    return [(label, count) for label, count in counts.items() if count < min_samples_per_class]
+
+
+def check_split_readiness(
+    dataset: Dataset,
+    split: str,
+    test_size: float,
+    min_samples_per_class: int = 2,
+    allow_oos_in_train: bool | None = None,
+) -> SplitReadinessResult:
+    """Check whether the dataset has enough samples per class for stratified splitting.
+
+    Uses the same OOS and stratification logic as :func:`split_dataset`, so downstream
+    code can call this before creating a :class:`DataHandler` or calling :func:`split_dataset`
+    and handle underpopulated classes (e.g. skip a phase, log, or fail with a clear message).
+
+    Args:
+        dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`).
+        split: The split name to check (e.g. ``Split.TRAIN``).
+        test_size: Proportion used for the test split (must match the value used when splitting).
+        min_samples_per_class: Minimum number of samples per class required for stratification.
+            Default 2 matches sklearn's requirement for a 2-way stratified split.
+        allow_oos_in_train: Same as in :func:`split_dataset`. Must be set when the dataset
+            contains OOS samples; otherwise a ``ValueError`` is raised.
+
+    Returns:
+        SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``.
+
+    Raises:
+        ValueError: If the dataset contains OOS samples and ``allow_oos_in_train`` is None.
+    """
+    if split not in dataset:
+        return SplitReadinessResult(
+            ready=False,
+            underpopulated_classes=[],
+            min_samples_per_class_required=min_samples_per_class,
+            reason=f"Dataset has no split '{split}'.",
+        )
+    hf_split = dataset[split]
+    splitter = StratifiedSplitter(
+        test_size=test_size,
+        label_feature=dataset.label_feature,
+        random_seed=None,
+    )
+    inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train)
+    if dataset.multilabel:
+        # Multilabel stratification uses IterativeStratification; we do not validate it here.
+        # (Checked on dataset.multilabel: the allow-OOS inputs carry list labels that Counter cannot hash.)
+        return SplitReadinessResult(
+            ready=True,
+            underpopulated_classes=[],
+            min_samples_per_class_required=min_samples_per_class,
+            reason=None,
+        )
+    underpopulated = _check_multiclass_counts(inputs.dataset, splitter.label_feature, min_samples_per_class)
+    ready = len(underpopulated) == 0
+    reason = None
+    if not ready:
+        parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated]
+        reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format(
+            min_samples_per_class, "; ".join(parts)
+        )
+    return SplitReadinessResult(
+        ready=ready,
+        underpopulated_classes=underpopulated,
+        min_samples_per_class_required=min_samples_per_class,
+        reason=reason,
+    )
+
+
 def split_dataset(
     dataset: Dataset,
     split: str,
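
For orientation (an illustration, not part of the patch): a minimal usage sketch of the
new readiness check, where `dataset` and `logger` are assumed to exist in the caller.

    from autointent.context.data_handler import check_split_readiness, split_dataset
    from autointent.custom_types import Split

    # Probe before splitting instead of catching sklearn's stratification error mid-run.
    result = check_split_readiness(dataset, split=Split.TRAIN, test_size=0.2, allow_oos_in_train=False)
    if result.ready:
        train, test = split_dataset(
            dataset, split=Split.TRAIN, test_size=0.2, random_seed=42, allow_oos_in_train=False
        )
    else:
        # result.underpopulated_classes is e.g. [(3, 1)]: class 3 has a single sample
        logger.warning("Skipping stratified split: %s", result.reason)
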
diff --git a/src/autointent/modules/scoring/_bert.py b/src/autointent/modules/scoring/_bert.py
index 4e80e06ad..5a1c9a962 100644
--- a/src/autointent/modules/scoring/_bert.py
+++ b/src/autointent/modules/scoring/_bert.py
@@ -143,6 +143,7 @@ def _initialize_model(self) -> Any:  # noqa: ANN401
         return AutoModelForSequenceClassification.from_pretrained(
             self.classification_model_config.model_name,
             trust_remote_code=self.classification_model_config.trust_remote_code,
+            revision=self.classification_model_config.revision,
             num_labels=self._n_classes,
             label2id=label2id,
             id2label=id2label,
@@ -156,7 +157,9 @@ def fit(
     ) -> None:
         self._validate_task(labels)
 
-        self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.classification_model_config.model_name, revision=self.classification_model_config.revision
+        )
         self._model = self._initialize_model()
         tokenized_dataset = self._get_tokenized_dataset(utterances, labels)
         self._train(tokenized_dataset)
@@ -200,8 +203,8 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
             callbacks=self._get_trainer_callbacks(),
         )
         if not self.print_progress:
-            trainer.remove_callback(PrinterCallback)  # type: ignore[no-untyped-call]
-            trainer.remove_callback(ProgressCallback)  # type: ignore[no-untyped-call]
+            trainer.remove_callback(PrinterCallback)
+            trainer.remove_callback(ProgressCallback)
         trainer.train()
 
diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml
index fd63f4396..b42781a04 100644
--- a/tests/assets/configs/multiclass.yaml
+++ b/tests/assets/configs/multiclass.yaml
@@ -39,6 +39,7 @@
         features_type: ["embedding"]
         embedder_config:
           - model_name: prajjwal1/bert-tiny
+            revision: refs/pr/16
       - module_name: bert
         classification_model_config:
           - model_name: avsolatorio/GIST-small-Embedding-v0
@@ -61,7 +62,9 @@
           seed: [0]
          lora_alpha: [16]
       - module_name: ptuning
-        classification_model_config: ["prajjwal1/bert-tiny"]
+        classification_model_config:
+          - model_name: prajjwal1/bert-tiny
+            revision: refs/pr/16
         num_train_epochs: [1]
         batch_size: [8, 16]
         num_virtual_tokens: [10, 20]
diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml
index 6dbb0c409..442cc38bf 100644
--- a/tests/assets/configs/multilabel.yaml
+++ b/tests/assets/configs/multilabel.yaml
@@ -35,6 +35,7 @@
         embedder_config:
           - null
           - model_name: prajjwal1/bert-tiny
+            revision: refs/pr/16
       - module_name: bert
         classification_model_config:
           - model_name: avsolatorio/GIST-small-Embedding-v0
@@ -49,7 +50,9 @@
           kernel_sizes: [[3, 4, 5]]
           num_filters: [100]
       - module_name: ptuning
-        classification_model_config: ["prajjwal1/bert-tiny"]
+        classification_model_config:
+          - model_name: prajjwal1/bert-tiny
+            revision: refs/pr/16
         num_train_epochs: [1]
         batch_size: [8]
         num_virtual_tokens: [10, 20]
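
Going by `HFModelConfig.from_search_config` earlier in this diff, the YAML changes above
swap the string-shorthand candidate for a mapping because only the mapping form can carry
the new `revision` field; a sketch of the equivalence (not part of the patch):

    from autointent.configs import HFModelConfig

    # String shorthand: presumably expanded to a config with just the model name...
    a = HFModelConfig.from_search_config("prajjwal1/bert-tiny")
    # ...while the mapping form passes extra fields through cls(**values) as shown above.
    b = HFModelConfig.from_search_config({"model_name": "prajjwal1/bert-tiny", "revision": "refs/pr/16"})
    assert a.model_name == b.model_name and b.revision == "refs/pr/16"
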
diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py
index 1a5c646c8..47a3269a2 100644
--- a/tests/callback/test_callback.py
+++ b/tests/callback/test_callback.py
@@ -145,6 +145,7 @@ def test_pipeline_callbacks(dataset):
                 "cluster_prompt": None,
                 "sts_prompt": None,
                 "query_prompt": None,
+                "revision": None,
                 "passage_prompt": None,
                 "similarity_fn_name": None,
                 "use_cache": True,
@@ -176,6 +177,7 @@ def test_pipeline_callbacks(dataset):
                 "cluster_prompt": None,
                 "sts_prompt": None,
                 "query_prompt": None,
+                "revision": None,
                 "passage_prompt": None,
                 "similarity_fn_name": None,
                 "use_cache": True,
@@ -203,6 +205,7 @@ def test_pipeline_callbacks(dataset):
                 "cluster_prompt": None,
                 "sts_prompt": None,
                 "query_prompt": None,
+                "revision": None,
                 "passage_prompt": None,
                 "similarity_fn_name": None,
                 "use_cache": True,
diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py
new file mode 100644
index 000000000..962074fc4
--- /dev/null
+++ b/tests/data/test_check_split_readiness.py
@@ -0,0 +1,230 @@
+"""Unit tests for check_split_readiness and SplitReadinessResult."""
+
+import pytest
+
+from autointent import Dataset
+from autointent.context.data_handler import (
+    SplitReadinessResult,
+    check_split_readiness,
+    split_dataset,
+)
+from autointent.custom_types import Split
+
+
+@pytest.fixture
+def dataset_enough_samples():
+    """Multiclass dataset with ≥2 samples per class (no OOS). Ready for stratification."""
+    return Dataset.from_dict(
+        {
+            "train": [
+                {"utterance": "a1", "label": 0},
+                {"utterance": "a2", "label": 0},
+                {"utterance": "b1", "label": 1},
+                {"utterance": "b2", "label": 1},
+                {"utterance": "c1", "label": 2},
+                {"utterance": "c2", "label": 2},
+            ],
+            "test": [
+                {"utterance": "t1", "label": 0},
+                {"utterance": "t2", "label": 1},
+                {"utterance": "t3", "label": 2},
+            ],
+            "intents": [
+                {"id": 0, "regex_full_match": [], "regex_partial_match": []},
+                {"id": 1, "regex_full_match": [], "regex_partial_match": []},
+                {"id": 2, "regex_full_match": [], "regex_partial_match": []},
+            ],
+        }
+    )
+
+
+@pytest.fixture
+def dataset_underpopulated():
+    """Multiclass dataset with one class having only 1 sample. Not ready for stratification."""
+    return Dataset.from_dict(
+        {
+            "train": [
+                {"utterance": "a1", "label": 0},
+                {"utterance": "a2", "label": 0},
+                {"utterance": "b1", "label": 1},
+            ],
+            "test": [
+                {"utterance": "t1", "label": 0},
+                {"utterance": "t2", "label": 1},
+            ],
+            "intents": [
+                {"id": 0, "regex_full_match": [], "regex_partial_match": []},
+                {"id": 1, "regex_full_match": [], "regex_partial_match": []},
+            ],
+        }
+    )
+
+
+@pytest.fixture
+def dataset_two_classes_barely_enough():
+    """Two classes with exactly 2 samples each. Ready for default min_samples_per_class=2."""
+    return Dataset.from_dict(
+        {
+            "train": [
+                {"utterance": "a1", "label": 0},
+                {"utterance": "a2", "label": 0},
+                {"utterance": "b1", "label": 1},
+                {"utterance": "b2", "label": 1},
+            ],
+            "test": [
+                {"utterance": "t1", "label": 0},
+                {"utterance": "t2", "label": 1},
+            ],
+            "intents": [
+                {"id": 0, "regex_full_match": [], "regex_partial_match": []},
+                {"id": 1, "regex_full_match": [], "regex_partial_match": []},
+            ],
+        }
+    )
+
+
+def test_check_split_readiness_ready_when_enough_samples(dataset_enough_samples):
+    """When every class has ≥ min_samples_per_class, result is ready."""
+    result = check_split_readiness(
+        dataset_enough_samples,
+        split=Split.TRAIN,
+        test_size=0.3,
+        allow_oos_in_train=False,
+    )
+    assert isinstance(result, SplitReadinessResult)
+    assert result.ready is True
+    assert result.underpopulated_classes == []
+    assert result.min_samples_per_class_required == 2
+    assert result.reason is None
+
+
+def test_check_split_readiness_not_ready_underpopulated(dataset_underpopulated):
+    """When at least one class has fewer than min samples, result is not ready."""
+    result = check_split_readiness(
+        dataset_underpopulated,
+        split=Split.TRAIN,
+        test_size=0.3,
+        allow_oos_in_train=False,
+    )
+    assert result.ready is False
+    assert result.min_samples_per_class_required == 2
+    assert len(result.underpopulated_classes) == 1
+    label, count = result.underpopulated_classes[0]
+    assert label == 1
+    assert count == 1
+    assert result.reason is not None
+    assert "class 1" in result.reason
+    assert "1 (need 2)" in result.reason
+
+
+def test_check_split_readiness_missing_split(dataset_enough_samples):
+    """When split is not in dataset, result is not ready with reason."""
+    result = check_split_readiness(
+        dataset_enough_samples,
+        split="nonexistent_split",
+        test_size=0.3,
+    )
+    assert result.ready is False
+    assert result.underpopulated_classes == []
+    assert "nonexistent_split" in result.reason
+
+
+def test_check_split_readiness_oos_allow_none(dataset_unsplitted):
+    """When the dataset has OOS and allow_oos_in_train is None, a ValueError is raised."""
+    with pytest.raises(ValueError, match="allow_oos_in_train"):
+        check_split_readiness(
+            dataset_unsplitted,
+            split=Split.TRAIN,
+            test_size=0.5,
+            allow_oos_in_train=None,
+        )
+
+
+def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitted):
+    """With OOS and allow_oos_in_train=False, in-domain classes are checked; clinc subset has enough."""
+    result = check_split_readiness(
+        dataset_unsplitted,
+        split=Split.TRAIN,
+        test_size=0.5,
+        allow_oos_in_train=False,
+    )
+    assert result.ready is True
+    assert result.underpopulated_classes == []
+    assert result.reason is None
+
+
+def test_check_split_readiness_min_samples_per_class_param(dataset_two_classes_barely_enough):
+    """Custom min_samples_per_class is respected."""
+    result = check_split_readiness(
+        dataset_two_classes_barely_enough,
+        split=Split.TRAIN,
+        test_size=0.3,
+        min_samples_per_class=2,
+        allow_oos_in_train=False,
+    )
+    assert result.ready is True
+
+    result_strict = check_split_readiness(
+        dataset_two_classes_barely_enough,
+        split=Split.TRAIN,
+        test_size=0.3,
+        min_samples_per_class=3,
+        allow_oos_in_train=False,
+    )
+    assert result_strict.ready is False
+    assert len(result_strict.underpopulated_classes) == 2
+    assert result_strict.min_samples_per_class_required == 3
+
+
+def test_check_split_readiness_multilabel_returns_ready(dataset_unsplitted):
+    """Multilabel datasets return ready=True (multilabel stratification is not validated)."""
+    dataset = dataset_unsplitted.to_multilabel()
+    result = check_split_readiness(
+        dataset,
+        split=Split.TRAIN,
+        test_size=0.5,
+        allow_oos_in_train=False,
+    )
+    assert result.ready is True
+    assert result.underpopulated_classes == []
+
+
+def test_check_split_readiness_consistent_with_split_dataset(dataset_enough_samples):
+    """When check_split_readiness says ready, split_dataset does not raise."""
+    result = check_split_readiness(
+        dataset_enough_samples,
+        split=Split.TRAIN,
+        test_size=0.5,
+        allow_oos_in_train=False,
+    )
+    assert result.ready is True
+
+    train, test = split_dataset(
+        dataset_enough_samples,
+        split=Split.TRAIN,
+        test_size=0.5,
+        random_seed=42,
+        allow_oos_in_train=False,
+    )
+    assert len(train) > 0
+    assert len(test) > 0
+
+
+def test_check_split_readiness_underpopulated_implies_split_raises(dataset_underpopulated):
+    """When check_split_readiness says not ready (underpopulated), split_dataset raises."""
+    result = check_split_readiness(
+        dataset_underpopulated,
+        split=Split.TRAIN,
+        test_size=0.3,
+        allow_oos_in_train=False,
+    )
+    assert result.ready is False
+
+    with pytest.raises(ValueError, match=r"least populated|too few"):
+        split_dataset(
+            dataset_underpopulated,
+            split=Split.TRAIN,
+            test_size=0.3,
+            random_seed=42,
+            allow_oos_in_train=False,
+        )
diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py
index 86fbc685b..18de9c8ea 100644
--- a/tests/modules/scoring/test_bert.py
+++ b/tests/modules/scoring/test_bert.py
@@ -5,16 +5,19 @@
 import numpy as np
 import pytest
 
+from autointent.configs import HFModelConfig
 from autointent.context.data_handler import DataHandler
 from autointent.modules import BertScorer
 
+_config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
+
 
 def test_bert_scorer_dump_load(dataset):
     """Test that BertScorer can be saved and loaded while preserving predictions."""
     data_handler = DataHandler(dataset)
 
     # Create and train scorer
-    scorer_original = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer_original = BertScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8)
     scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
     # Test data
@@ -57,7 +60,7 @@ def test_bert_prediction(dataset):
     """Test that the transformer model can fit and make predictions."""
     data_handler = DataHandler(dataset)
 
-    scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BertScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8)
 
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
@@ -94,7 +97,7 @@ def test_bert_cache_clearing(dataset):
     """Test that the transformer model properly handles cache clearing."""
     data_handler = DataHandler(dataset)
 
-    scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BertScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8)
 
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
diff --git a/tests/modules/scoring/test_catboost.py b/tests/modules/scoring/test_catboost.py
index 177936980..53a28b995 100644
--- a/tests/modules/scoring/test_catboost.py
+++ b/tests/modules/scoring/test_catboost.py
@@ -5,9 +5,12 @@
 import numpy as np
 import pytest
 
+from autointent.configs import SentenceTransformerEmbeddingConfig
 from autointent.context.data_handler import DataHandler
 from autointent.modules import CatBoostScorer
 
+_embedder_config = SentenceTransformerEmbeddingConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
+
 
 def test_catboost_scorer_dump_load(dataset):
     """Test that CatBoostScorer can be saved and loaded while preserving predictions."""
@@ -53,7 +56,7 @@ def test_catboost_prediction_multilabel(dataset):
     data_handler = DataHandler(dataset.to_multilabel())
 
     scorer = CatBoostScorer(
-        embedder_config="prajjwal1/bert-tiny",
+        embedder_config=_embedder_config,
         iterations=50,
         learning_rate=0.05,
         depth=6,
@@ -97,7 +100,7 @@ def test_catboost_features_types(dataset, features_type, use_embedding_features):
     data_handler = DataHandler(dataset)
 
     scorer = CatBoostScorer(
-        embedder_config="prajjwal1/bert-tiny",
+        embedder_config=_embedder_config,
         iterations=50,
         learning_rate=0.05,
         depth=6,
diff --git a/tests/modules/scoring/test_gcn_scorer.py b/tests/modules/scoring/test_gcn_scorer.py
index e51d7c547..07a302a1c 100644
--- a/tests/modules/scoring/test_gcn_scorer.py
+++ b/tests/modules/scoring/test_gcn_scorer.py
@@ -3,8 +3,11 @@
 import torch
 
 from autointent import Dataset
+from autointent.configs import SentenceTransformerEmbeddingConfig
 from autointent.modules.scoring import GCNScorer
 
+_embedder_config = SentenceTransformerEmbeddingConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
+
 
 @pytest.fixture
 def multilabel_dataset():
@@ -44,7 +47,7 @@ def multiclass_dataset():
 def test_gcn_scorer_multilabel(multilabel_dataset):
     torch.manual_seed(42)
-    scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42)
+    scorer = GCNScorer(embedder_config=_embedder_config, num_train_epochs=1, batch_size=2, seed=42)
     train_utterances = multilabel_dataset["train"]["utterance"]
     train_labels = multilabel_dataset["train"]["label"]
     descriptions = [intent.name for intent in multilabel_dataset.intents]
 
@@ -59,7 +62,7 @@
 def test_gcn_scorer_multiclass(multiclass_dataset):
     torch.manual_seed(42)
-    scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42)
+    scorer = GCNScorer(embedder_config=_embedder_config, num_train_epochs=1, batch_size=2, seed=42)
     train_utterances = multiclass_dataset["train"]["utterance"]
     train_labels = multiclass_dataset["train"]["label"]
     descriptions = [intent.name for intent in multiclass_dataset.intents]
 
@@ -75,7 +78,7 @@
 def test_gcn_scorer_dump_load(tmp_path, multilabel_dataset):
     torch.manual_seed(42)
-    scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42)
+    scorer = GCNScorer(embedder_config=_embedder_config, num_train_epochs=1, batch_size=2, seed=42)
     train_utterances = multilabel_dataset["train"]["utterance"]
     train_labels = multilabel_dataset["train"]["label"]
     descriptions = [intent.name for intent in multilabel_dataset.intents]
diff --git a/tests/modules/scoring/test_lora.py b/tests/modules/scoring/test_lora.py
index f9d2725fd..c4245628b 100644
--- a/tests/modules/scoring/test_lora.py
+++ b/tests/modules/scoring/test_lora.py
@@ -5,18 +5,19 @@
 import numpy as np
 import pytest
 
+from autointent.configs import HFModelConfig
 from autointent.context.data_handler import DataHandler
 from autointent.modules import BERTLoRAScorer
 
+_config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
+
 
 def test_lora_scorer_dump_load(dataset):
     """Test that BERTLoRAScorer can be saved and loaded while preserving predictions."""
     data_handler = DataHandler(dataset)
 
     # Create and train scorer
-    scorer_original = BERTLoRAScorer(
-        classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8
-    )
+    scorer_original = BERTLoRAScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8)
     scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
     # Test data
@@ -59,7 +60,7 @@ def test_lora_prediction(dataset):
     """Test that the lora model can fit and make predictions."""
     data_handler = DataHandler(dataset)
 
-    scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BERTLoRAScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8)
 
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
@@ -96,7 +97,7 @@ def test_lora_cache_clearing(dataset):
     """Test that the lora model properly handles cache clearing."""
     data_handler = DataHandler(dataset)
 
-    scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BERTLoRAScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8)
 
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
diff --git a/tests/modules/scoring/test_ptuning.py b/tests/modules/scoring/test_ptuning.py
index d74fe39bf..d247b6210 100644
--- a/tests/modules/scoring/test_ptuning.py
+++ b/tests/modules/scoring/test_ptuning.py
@@ -5,16 +5,19 @@
 import numpy as np
 import pytest
 
+from autointent.configs import HFModelConfig
 from autointent.context.data_handler import DataHandler
 from autointent.modules import PTuningScorer
 
+_config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16")
+
 
 def test_ptuning_scorer_dump_load(dataset):
     """Test that PTuningScorer can be saved and loaded while preserving predictions."""
     data_handler = DataHandler(dataset)
 
     scorer_original = PTuningScorer(
-        classification_model_config="prajjwal1/bert-tiny",
+        classification_model_config=_config,
         num_train_epochs=1,
         batch_size=8,
         num_virtual_tokens=10,
@@ -54,7 +57,7 @@ def test_ptuning_prediction(dataset):
     data_handler = DataHandler(dataset)
 
     scorer = PTuningScorer(
-        classification_model_config="prajjwal1/bert-tiny",
+        classification_model_config=_config,
         num_train_epochs=1,
         batch_size=8,
         num_virtual_tokens=10,
@@ -93,7 +96,7 @@ def test_ptuning_cache_clearing(dataset):
     data_handler = DataHandler(dataset)
 
     scorer = PTuningScorer(
-        classification_model_config="prajjwal1/bert-tiny",
+        classification_model_config=_config,
         num_train_epochs=1,
         batch_size=8,
         num_virtual_tokens=20,