From 847fb1613048c0863bcf408c603f470751b7a33f Mon Sep 17 00:00:00 2001 From: voorhs Date: Wed, 11 Mar 2026 22:09:17 +0300 Subject: [PATCH 01/21] [ai] first version --- .../context/data_handler/__init__.py | 15 +- .../context/data_handler/_stratification.py | 252 +++++++++++++----- tests/data/test_check_split_readiness.py | 232 ++++++++++++++++ 3 files changed, 437 insertions(+), 62 deletions(-) create mode 100644 tests/data/test_check_split_readiness.py diff --git a/src/autointent/context/data_handler/__init__.py b/src/autointent/context/data_handler/__init__.py index a2e91358f..24a3d3b55 100644 --- a/src/autointent/context/data_handler/__init__.py +++ b/src/autointent/context/data_handler/__init__.py @@ -1,4 +1,15 @@ from ._data_handler import DataHandler -from ._stratification import StratifiedSplitter, split_dataset +from ._stratification import ( + SplitReadinessResult, + StratifiedSplitter, + check_split_readiness, + split_dataset, +) -__all__ = ["DataHandler", "StratifiedSplitter", "split_dataset"] +__all__ = [ + "DataHandler", + "SplitReadinessResult", + "StratifiedSplitter", + "check_split_readiness", + "split_dataset", +] diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index 1482383ab..411bad931 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -7,6 +7,8 @@ from __future__ import annotations import logging +from collections import Counter +from dataclasses import dataclass from typing import TYPE_CHECKING import numpy as np @@ -16,7 +18,7 @@ from transformers import set_seed if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence from datasets import Dataset as HFDataset from numpy import typing as npt @@ -27,6 +29,37 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class StratifyInputs: + """Inputs for stratified splitting: the effective dataset and post-split hook. + + Used internally so that the same OOS/mapping logic feeds both splitting and + readiness checks. + """ + + dataset: HFDataset + multilabel: bool + test_size: float + post_split_fn: Callable[[HFDataset, HFDataset], tuple[HFDataset, HFDataset]] + + +@dataclass(frozen=True) +class SplitReadinessResult: + """Result of checking whether a dataset can be stratified split. + + Attributes: + ready: True if stratification can be performed (enough samples per class). + underpopulated_classes: List of (label, count) for classes below the minimum. + min_samples_per_class_required: Minimum samples per class used for the check. + reason: Human-readable reason when not ready (e.g. OOS not configured). + """ + + ready: bool + underpopulated_classes: list[tuple[int | str | None, int]] + min_samples_per_class_required: int + reason: str | None + + class StratifiedSplitter: """A class for stratified splitting of datasets. @@ -77,7 +110,7 @@ def __call__( Raises: ValueError: If OOS samples are present but allow_oos_in_train is not specified. """ - if not self._has_oos_samples(dataset): + if not self.has_oos_samples(dataset): train, test = self._split_without_oos(dataset, multilabel, self.test_size) if self.is_few_shot: train, test = create_few_shot_split( @@ -94,8 +127,9 @@ def __call__( "you need to set the parameter allow_oos_in_train." ) raise ValueError(msg) - splitter = self._split_allow_oos_in_train if allow_oos_in_train else self._split_disallow_oos_in_train - train, test = splitter(dataset, multilabel) + inputs = self.get_stratify_inputs(dataset, multilabel, allow_oos_in_train) + train, test = self._split_without_oos(inputs.dataset, inputs.multilabel, inputs.test_size) + train, test = inputs.post_split_fn(train, test) if self.is_few_shot: train, test = create_few_shot_split( train, @@ -106,7 +140,7 @@ def __call__( ) return train, test - def _has_oos_samples(self, dataset: HFDataset) -> bool: + def has_oos_samples(self, dataset: HFDataset) -> bool: """Check if the dataset contains out-of-scope samples. Args: @@ -118,6 +152,79 @@ def _has_oos_samples(self, dataset: HFDataset) -> bool: oos_samples = dataset.filter(lambda sample: sample[self.label_feature] is None) return len(oos_samples) > 0 + def get_stratify_inputs(self, dataset: HFDataset, multilabel: bool, allow_oos_in_train: bool) -> StratifyInputs: + """Return the effective dataset and post-split hook for stratification. + + Single source of truth for OOS handling: both splitting and readiness + checks use this so logic is not duplicated. + + Args: + dataset: The input dataset (may contain OOS). + multilabel: Whether the dataset is multi-label. + allow_oos_in_train: Whether OOS samples are allowed in the train split. + + Returns: + StratifyInputs with the dataset to stratify on and a post_split_fn. + """ + if not self.has_oos_samples(dataset): + return StratifyInputs( + dataset=dataset, + multilabel=multilabel, + test_size=self.test_size, + post_split_fn=lambda train_ds, test_ds: (train_ds, test_ds), + ) + if allow_oos_in_train: + if multilabel: + in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) + n_classes = len(in_domain_sample[self.label_feature]) + mapped_dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes}) + + def unmap_oos_multilabel(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: + return ( + train_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}), + test_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}), + ) + + return StratifyInputs( + dataset=mapped_dataset, + multilabel=False, + test_size=self.test_size, + post_split_fn=unmap_oos_multilabel, + ) + oos_class_id = len(dataset.unique(self.label_feature)) - 1 + mapped_dataset = dataset.map(self._map_label, fn_kwargs={"old": None, "new": oos_class_id}) + + def unmap_oos_multiclass(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: + return ( + train_ds.map( + self._map_label, + fn_kwargs={"old": oos_class_id, "new": None}, + ), + test_ds.map( + self._map_label, + fn_kwargs={"old": oos_class_id, "new": None}, + ), + ) + + return StratifyInputs( + dataset=mapped_dataset, + multilabel=False, + test_size=self.test_size, + post_split_fn=unmap_oos_multiclass, + ) + in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset) + adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset)) + + def concat_oos_to_test(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: + return (train_ds, concatenate_datasets([test_ds, out_of_domain_dataset])) + + return StratifyInputs( + dataset=in_domain_dataset, + multilabel=multilabel, + test_size=adjusted_test_size, + post_split_fn=concat_oos_to_test, + ) + def _split_without_oos(self, dataset: HFDataset, multilabel: bool, test_size: float) -> tuple[HFDataset, HFDataset]: """Split dataset that doesn't contain OOS samples. @@ -170,42 +277,6 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np ) return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature]))) - def _split_allow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]: - """Proportionally distribute OOS samples between two splits. - - Internally creates a dataset copy with some integer assigned as OOS class id. - With OOS samples treated as a separate class we obtain proportional distribution - of them between two splits. - - Args: - dataset: Dataset to split. - multilabel: Whether the dataset is multi-label. - - Returns: - A tuple containing training and testing datasets. - """ - # add oos as a class - if multilabel: - in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) - n_classes = len(in_domain_sample[self.label_feature]) - dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes}) - else: - oos_class_id = len(dataset.unique(self.label_feature)) - 1 - dataset = dataset.map(self._map_label, fn_kwargs={"old": None, "new": oos_class_id}) - - # perform stratified splitting - train, test = self._split_without_oos(dataset, multilabel=False, test_size=self.test_size) - - # remove oos as a class - if multilabel: - train = train.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}) - test = test.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}) - else: - train = train.map(self._map_label, fn_kwargs={"old": oos_class_id, "new": None}) - test = test.map(self._map_label, fn_kwargs={"old": oos_class_id, "new": None}) - - return train, test - def _map_label( self, sample: dict[str, str | LabelType], old: LabelType, new: LabelType ) -> dict[str, str | LabelType]: @@ -253,25 +324,6 @@ def _remove_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) sample[self.label_feature] = None # type: ignore[assignment] return sample - def _split_disallow_oos_in_train(self, dataset: HFDataset, multilabel: bool) -> tuple[HFDataset, HFDataset]: - """Move all OOS samples to test split. - - This method preserves the defined test_size proportion so you won't get unexpectedly - large test set even you have lots of OOS samples. - - Args: - dataset: Dataset to split. - multilabel: Whether the dataset is multi-label. - - Returns: - A tuple containing training and testing datasets. - """ - in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset) - adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset)) - train, test = self._split_without_oos(in_domain_dataset, multilabel, adjusted_test_size) - test = concatenate_datasets([test, out_of_domain_dataset]) - return train, test - def _separate_oos(self, dataset: HFDataset) -> tuple[HFDataset, HFDataset]: """Separate OOS samples from in-domain samples. @@ -315,6 +367,86 @@ def _get_adjusted_test_size(self, n: int, k: int) -> float: return res +def _check_multiclass_counts( + dataset: HFDataset, label_feature: str, min_samples_per_class: int +) -> list[tuple[int | str | None, int]]: + """Return (label, count) for each class with fewer than min_samples_per_class samples.""" + labels = dataset[label_feature] + counts = Counter(labels) + return [(label, count) for label, count in counts.items() if count < min_samples_per_class] + + +def check_split_readiness( + dataset: Dataset, + split: str, + test_size: float, + min_samples_per_class: int = 2, + allow_oos_in_train: bool | None = None, +) -> SplitReadinessResult: + """Check whether the dataset has enough samples per class for stratified splitting. + + Uses the same OOS and stratification logic as :func:`split_dataset`, so downstream + code can call this before creating a :class:`DataHandler` or calling :func:`split_dataset` + and handle underpopulated classes (e.g. skip phase, log, or fail with a clear message). + + Args: + dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). + split: The split name to check (e.g. ``Split.TRAIN``). + test_size: Proportion used for the test split (must match the value used when splitting). + min_samples_per_class: Minimum number of samples per class required for stratification. + Default 2 matches sklearn's requirement for a 2-way stratified split. + allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples + and this is not set, the function returns ``ready=False`` with a reason. + + Returns: + SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``. + """ + if split not in dataset: + return SplitReadinessResult( + ready=False, + underpopulated_classes=[], + min_samples_per_class_required=min_samples_per_class, + reason=f"Dataset has no split '{split}'.", + ) + hf_split = dataset[split] + splitter = StratifiedSplitter( + test_size=test_size, + label_feature=dataset.label_feature, + random_seed=None, + ) + if splitter.has_oos_samples(hf_split) and allow_oos_in_train is None: + return SplitReadinessResult( + ready=False, + underpopulated_classes=[], + min_samples_per_class_required=min_samples_per_class, + reason="OOS samples present; set allow_oos_in_train to check readiness.", + ) + allow = allow_oos_in_train if allow_oos_in_train is not None else False + inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow) + if inputs.multilabel: + # Multilabel stratification uses IterativeStratification; we do not validate it here. + return SplitReadinessResult( + ready=True, + underpopulated_classes=[], + min_samples_per_class_required=min_samples_per_class, + reason=None, + ) + underpopulated = _check_multiclass_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) + ready = len(underpopulated) == 0 + reason = None + if not ready: + parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] + reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format( + min_samples_per_class, "; ".join(parts) + ) + return SplitReadinessResult( + ready=ready, + underpopulated_classes=underpopulated, + min_samples_per_class_required=min_samples_per_class, + reason=reason, + ) + + def split_dataset( dataset: Dataset, split: str, diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py new file mode 100644 index 000000000..03d4162b9 --- /dev/null +++ b/tests/data/test_check_split_readiness.py @@ -0,0 +1,232 @@ +"""Unit tests for check_split_readiness and SplitReadinessResult.""" + +import pytest + +from autointent import Dataset +from autointent.context.data_handler import ( + SplitReadinessResult, + check_split_readiness, +) +from autointent.custom_types import Split + + +@pytest.fixture +def dataset_enough_samples(): + """Multiclass dataset with ≥2 samples per class (no OOS). Ready for stratification.""" + return Dataset.from_dict( + { + "train": [ + {"utterance": "a1", "label": 0}, + {"utterance": "a2", "label": 0}, + {"utterance": "b1", "label": 1}, + {"utterance": "b2", "label": 1}, + {"utterance": "c1", "label": 2}, + {"utterance": "c2", "label": 2}, + ], + "test": [ + {"utterance": "t1", "label": 0}, + {"utterance": "t2", "label": 1}, + {"utterance": "t3", "label": 2}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + {"id": 2, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + +@pytest.fixture +def dataset_underpopulated(): + """Multiclass dataset with one class having only 1 sample. Not ready for stratification.""" + return Dataset.from_dict( + { + "train": [ + {"utterance": "a1", "label": 0}, + {"utterance": "a2", "label": 0}, + {"utterance": "b1", "label": 1}, + ], + "test": [ + {"utterance": "t1", "label": 0}, + {"utterance": "t2", "label": 1}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + +@pytest.fixture +def dataset_two_classes_barely_enough(): + """Two classes with exactly 2 samples each. Ready for default min_samples_per_class=2.""" + return Dataset.from_dict( + { + "train": [ + {"utterance": "a1", "label": 0}, + {"utterance": "a2", "label": 0}, + {"utterance": "b1", "label": 1}, + {"utterance": "b2", "label": 1}, + ], + "test": [ + {"utterance": "t1", "label": 0}, + {"utterance": "t2", "label": 1}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + +def test_check_split_readiness_ready_when_enough_samples(dataset_enough_samples): + """When every class has ≥ min_samples_per_class, result is ready.""" + result = check_split_readiness( + dataset_enough_samples, + split=Split.TRAIN, + test_size=0.3, + allow_oos_in_train=False, + ) + assert isinstance(result, SplitReadinessResult) + assert result.ready is True + assert result.underpopulated_classes == [] + assert result.min_samples_per_class_required == 2 + assert result.reason is None + + +def test_check_split_readiness_not_ready_underpopulated(dataset_underpopulated): + """When at least one class has fewer than min samples, result is not ready.""" + result = check_split_readiness( + dataset_underpopulated, + split=Split.TRAIN, + test_size=0.3, + allow_oos_in_train=False, + ) + assert result.ready is False + assert result.min_samples_per_class_required == 2 + assert len(result.underpopulated_classes) == 1 + label, count = result.underpopulated_classes[0] + assert label == 1 + assert count == 1 + assert result.reason is not None + assert "class 1" in result.reason or "1" in result.reason + assert "1 (need 2)" in result.reason + + +def test_check_split_readiness_missing_split(dataset_enough_samples): + """When split is not in dataset, result is not ready with reason.""" + result = check_split_readiness( + dataset_enough_samples, + split="nonexistent_split", + test_size=0.3, + ) + assert result.ready is False + assert result.underpopulated_classes == [] + assert "nonexistent_split" in result.reason + + +def test_check_split_readiness_oos_allow_none(dataset_unsplitted): + """When dataset has OOS and allow_oos_in_train is None, result is not ready.""" + result = check_split_readiness( + dataset_unsplitted, + split=Split.TRAIN, + test_size=0.5, + allow_oos_in_train=None, + ) + assert result.ready is False + assert "OOS" in result.reason or "allow_oos_in_train" in result.reason + + +def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitted): + """With OOS and allow_oos_in_train=False, in-domain classes are checked; clinc subset has enough.""" + result = check_split_readiness( + dataset_unsplitted, + split=Split.TRAIN, + test_size=0.5, + allow_oos_in_train=False, + ) + assert result.ready is True + assert result.underpopulated_classes == [] + assert result.reason is None + + +def test_check_split_readiness_min_samples_per_class_param(dataset_two_classes_barely_enough): + """Custom min_samples_per_class is respected.""" + result = check_split_readiness( + dataset_two_classes_barely_enough, + split=Split.TRAIN, + test_size=0.3, + min_samples_per_class=2, + allow_oos_in_train=False, + ) + assert result.ready is True + + result_strict = check_split_readiness( + dataset_two_classes_barely_enough, + split=Split.TRAIN, + test_size=0.3, + min_samples_per_class=3, + allow_oos_in_train=False, + ) + assert result_strict.ready is False + assert len(result_strict.underpopulated_classes) == 2 + assert result_strict.min_samples_per_class_required == 3 + + +def test_check_split_readiness_multilabel_returns_ready(dataset_unsplitted): + """Multilabel datasets return ready=True (multilabel stratification is not validated).""" + dataset = dataset_unsplitted.to_multilabel() + result = check_split_readiness( + dataset, + split=Split.TRAIN, + test_size=0.5, + allow_oos_in_train=False, + ) + assert result.ready is True + assert result.underpopulated_classes == [] + + +def test_check_split_readiness_consistent_with_split_dataset(dataset_enough_samples): + """When check_split_readiness says ready, split_dataset does not raise.""" + result = check_split_readiness( + dataset_enough_samples, + split=Split.TRAIN, + test_size=0.5, + allow_oos_in_train=False, + ) + assert result.ready is True + from autointent.context.data_handler import split_dataset + + train, test = split_dataset( + dataset_enough_samples, + split=Split.TRAIN, + test_size=0.5, + random_seed=42, + allow_oos_in_train=False, + ) + assert len(train) > 0 + assert len(test) > 0 + + +def test_check_split_readiness_underpopulated_implies_split_raises(dataset_underpopulated): + """When check_split_readiness says not ready (underpopulated), split_dataset raises.""" + result = check_split_readiness( + dataset_underpopulated, + split=Split.TRAIN, + test_size=0.3, + allow_oos_in_train=False, + ) + assert result.ready is False + from autointent.context.data_handler import split_dataset + + with pytest.raises(ValueError, match=r"least populated|too few"): + split_dataset( + dataset_underpopulated, + split=Split.TRAIN, + test_size=0.3, + random_seed=42, + allow_oos_in_train=False, + ) From e5a8ac9bef93fcbff49d9f4001c828e60f46a637 Mon Sep 17 00:00:00 2001 From: voorhs Date: Wed, 11 Mar 2026 22:29:37 +0300 Subject: [PATCH 02/21] decompose `get_stratify_inputs` a little bit --- .../context/data_handler/_stratification.py | 106 ++++++++++++------ 1 file changed, 72 insertions(+), 34 deletions(-) diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index 411bad931..d01005f77 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -121,12 +121,6 @@ def __call__( examples_per_label=self.examples_per_label, ) return train, test - if allow_oos_in_train is None: - msg = ( - "Error while splitting dataset. It contains OOS samples, " - "you need to set the parameter allow_oos_in_train." - ) - raise ValueError(msg) inputs = self.get_stratify_inputs(dataset, multilabel, allow_oos_in_train) train, test = self._split_without_oos(inputs.dataset, inputs.multilabel, inputs.test_size) train, test = inputs.post_split_fn(train, test) @@ -152,7 +146,9 @@ def has_oos_samples(self, dataset: HFDataset) -> bool: oos_samples = dataset.filter(lambda sample: sample[self.label_feature] is None) return len(oos_samples) > 0 - def get_stratify_inputs(self, dataset: HFDataset, multilabel: bool, allow_oos_in_train: bool) -> StratifyInputs: + def get_stratify_inputs( + self, dataset: HFDataset, multilabel: bool, allow_oos_in_train: bool | None + ) -> StratifyInputs: """Return the effective dataset and post-split hook for stratification. Single source of truth for OOS handling: both splitting and readiness @@ -162,10 +158,20 @@ def get_stratify_inputs(self, dataset: HFDataset, multilabel: bool, allow_oos_in dataset: The input dataset (may contain OOS). multilabel: Whether the dataset is multi-label. allow_oos_in_train: Whether OOS samples are allowed in the train split. + Must be set when the dataset contains OOS samples. Returns: StratifyInputs with the dataset to stratify on and a post_split_fn. + + Raises: + ValueError: If the dataset contains OOS samples and allow_oos_in_train is None. """ + if self.has_oos_samples(dataset) and allow_oos_in_train is None: + msg = ( + "Error while splitting dataset. It contains OOS samples, " + "you need to set the parameter allow_oos_in_train." + ) + raise ValueError(msg) if not self.has_oos_samples(dataset): return StratifyInputs( dataset=dataset, @@ -174,35 +180,32 @@ def get_stratify_inputs(self, dataset: HFDataset, multilabel: bool, allow_oos_in post_split_fn=lambda train_ds, test_ds: (train_ds, test_ds), ) if allow_oos_in_train: - if multilabel: - in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) - n_classes = len(in_domain_sample[self.label_feature]) - mapped_dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes}) - - def unmap_oos_multilabel(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: - return ( - train_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}), - test_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}), - ) - - return StratifyInputs( - dataset=mapped_dataset, - multilabel=False, - test_size=self.test_size, - post_split_fn=unmap_oos_multilabel, - ) - oos_class_id = len(dataset.unique(self.label_feature)) - 1 - mapped_dataset = dataset.map(self._map_label, fn_kwargs={"old": None, "new": oos_class_id}) + return self._stratify_inputs_allow_oos(dataset, multilabel) + return self._stratify_inputs_disallow_oos(dataset, multilabel) + + def _stratify_inputs_allow_oos(self, dataset: HFDataset, multilabel: bool) -> StratifyInputs: + """Return stratify inputs when OOS samples are allowed in the train split. - def unmap_oos_multiclass(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: + OOS is mapped to a class so it is stratified; post_split_fn unmaps it. + """ + if multilabel: + in_domain_sample = next( + sample for sample in dataset if sample[self.label_feature] is not None + ) + n_classes = len(in_domain_sample[self.label_feature]) + mapped_dataset = dataset.map( + self._add_oos_label, fn_kwargs={"n_classes": n_classes} + ) + + def unmap_oos_multilabel( + train_ds: HFDataset, test_ds: HFDataset + ) -> tuple[HFDataset, HFDataset]: return ( train_ds.map( - self._map_label, - fn_kwargs={"old": oos_class_id, "new": None}, + self._remove_oos_label, fn_kwargs={"n_classes": n_classes} ), test_ds.map( - self._map_label, - fn_kwargs={"old": oos_class_id, "new": None}, + self._remove_oos_label, fn_kwargs={"n_classes": n_classes} ), ) @@ -210,12 +213,47 @@ def unmap_oos_multiclass(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDat dataset=mapped_dataset, multilabel=False, test_size=self.test_size, - post_split_fn=unmap_oos_multiclass, + post_split_fn=unmap_oos_multilabel, + ) + oos_class_id = len(dataset.unique(self.label_feature)) - 1 + mapped_dataset = dataset.map( + self._map_label, fn_kwargs={"old": None, "new": oos_class_id} + ) + + def unmap_oos_multiclass( + train_ds: HFDataset, test_ds: HFDataset + ) -> tuple[HFDataset, HFDataset]: + return ( + train_ds.map( + self._map_label, + fn_kwargs={"old": oos_class_id, "new": None}, + ), + test_ds.map( + self._map_label, + fn_kwargs={"old": oos_class_id, "new": None}, + ), ) + + return StratifyInputs( + dataset=mapped_dataset, + multilabel=False, + test_size=self.test_size, + post_split_fn=unmap_oos_multiclass, + ) + + def _stratify_inputs_disallow_oos(self, dataset: HFDataset, multilabel: bool) -> StratifyInputs: + """Return stratify inputs when OOS samples are not allowed in the train split. + + Only in-domain data is stratified; post_split_fn appends OOS to the test set. + """ in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset) - adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset)) + adjusted_test_size = self._get_adjusted_test_size( + len(dataset), len(out_of_domain_dataset) + ) - def concat_oos_to_test(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: + def concat_oos_to_test( + train_ds: HFDataset, test_ds: HFDataset + ) -> tuple[HFDataset, HFDataset]: return (train_ds, concatenate_datasets([test_ds, out_of_domain_dataset])) return StratifyInputs( From b6a8ad6ea1eb273508b19b1b1a81c0cef3861199 Mon Sep 17 00:00:00 2001 From: voorhs Date: Wed, 11 Mar 2026 22:33:10 +0300 Subject: [PATCH 03/21] remove redundant check in `get_stratify_inputs` --- .../context/data_handler/_stratification.py | 47 ++++--------------- 1 file changed, 9 insertions(+), 38 deletions(-) diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index d01005f77..8575dcf3f 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -110,17 +110,6 @@ def __call__( Raises: ValueError: If OOS samples are present but allow_oos_in_train is not specified. """ - if not self.has_oos_samples(dataset): - train, test = self._split_without_oos(dataset, multilabel, self.test_size) - if self.is_few_shot: - train, test = create_few_shot_split( - train, - test, - multilabel=multilabel, - label_column=self.label_feature, - examples_per_label=self.examples_per_label, - ) - return train, test inputs = self.get_stratify_inputs(dataset, multilabel, allow_oos_in_train) train, test = self._split_without_oos(inputs.dataset, inputs.multilabel, inputs.test_size) train, test = inputs.post_split_fn(train, test) @@ -189,24 +178,14 @@ def _stratify_inputs_allow_oos(self, dataset: HFDataset, multilabel: bool) -> St OOS is mapped to a class so it is stratified; post_split_fn unmaps it. """ if multilabel: - in_domain_sample = next( - sample for sample in dataset if sample[self.label_feature] is not None - ) + in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) n_classes = len(in_domain_sample[self.label_feature]) - mapped_dataset = dataset.map( - self._add_oos_label, fn_kwargs={"n_classes": n_classes} - ) + mapped_dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes}) - def unmap_oos_multilabel( - train_ds: HFDataset, test_ds: HFDataset - ) -> tuple[HFDataset, HFDataset]: + def unmap_oos_multilabel(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: return ( - train_ds.map( - self._remove_oos_label, fn_kwargs={"n_classes": n_classes} - ), - test_ds.map( - self._remove_oos_label, fn_kwargs={"n_classes": n_classes} - ), + train_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}), + test_ds.map(self._remove_oos_label, fn_kwargs={"n_classes": n_classes}), ) return StratifyInputs( @@ -216,13 +195,9 @@ def unmap_oos_multilabel( post_split_fn=unmap_oos_multilabel, ) oos_class_id = len(dataset.unique(self.label_feature)) - 1 - mapped_dataset = dataset.map( - self._map_label, fn_kwargs={"old": None, "new": oos_class_id} - ) + mapped_dataset = dataset.map(self._map_label, fn_kwargs={"old": None, "new": oos_class_id}) - def unmap_oos_multiclass( - train_ds: HFDataset, test_ds: HFDataset - ) -> tuple[HFDataset, HFDataset]: + def unmap_oos_multiclass(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: return ( train_ds.map( self._map_label, @@ -247,13 +222,9 @@ def _stratify_inputs_disallow_oos(self, dataset: HFDataset, multilabel: bool) -> Only in-domain data is stratified; post_split_fn appends OOS to the test set. """ in_domain_dataset, out_of_domain_dataset = self._separate_oos(dataset) - adjusted_test_size = self._get_adjusted_test_size( - len(dataset), len(out_of_domain_dataset) - ) + adjusted_test_size = self._get_adjusted_test_size(len(dataset), len(out_of_domain_dataset)) - def concat_oos_to_test( - train_ds: HFDataset, test_ds: HFDataset - ) -> tuple[HFDataset, HFDataset]: + def concat_oos_to_test(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDataset, HFDataset]: return (train_ds, concatenate_datasets([test_ds, out_of_domain_dataset])) return StratifyInputs( From 167782312ec40d30cf8b12697655158337c82e0f Mon Sep 17 00:00:00 2001 From: voorhs Date: Wed, 11 Mar 2026 22:40:41 +0300 Subject: [PATCH 04/21] remove unnecessary check --- src/autointent/context/data_handler/_stratification.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index 8575dcf3f..3ac3a3b59 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -423,15 +423,7 @@ def check_split_readiness( label_feature=dataset.label_feature, random_seed=None, ) - if splitter.has_oos_samples(hf_split) and allow_oos_in_train is None: - return SplitReadinessResult( - ready=False, - underpopulated_classes=[], - min_samples_per_class_required=min_samples_per_class, - reason="OOS samples present; set allow_oos_in_train to check readiness.", - ) - allow = allow_oos_in_train if allow_oos_in_train is not None else False - inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow) + inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train) if inputs.multilabel: # Multilabel stratification uses IterativeStratification; we do not validate it here. return SplitReadinessResult( From 3059bda8be98f00d82ceed4096137d0933568468 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 18:41:33 +0300 Subject: [PATCH 05/21] add protobuf dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e742242a4..5a089f256 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "aiometer (>=1.0.0,<2.0.0)", "aiofiles (>=24.1.0,<25.0.0)", "threadpoolctl (>=3.0.0,<4.0.0)", + "protobuf>=6,<7" ] [project.optional-dependencies] From a08a0a31e5cf26691929c2f6597256fd61da35f9 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 18:42:27 +0300 Subject: [PATCH 06/21] add `get_default_hfmodel_config` --- src/autointent/_pipeline/_pipeline.py | 3 ++- src/autointent/configs/__init__.py | 2 ++ src/autointent/configs/_transformers.py | 5 +++++ src/autointent/context/_context.py | 7 +++---- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/autointent/_pipeline/_pipeline.py b/src/autointent/_pipeline/_pipeline.py index 3d62f07d1..a764f1d78 100644 --- a/src/autointent/_pipeline/_pipeline.py +++ b/src/autointent/_pipeline/_pipeline.py @@ -22,6 +22,7 @@ LoggingConfig, VectorIndexConfig, get_default_embedder_config, + get_default_hfmodel_config, get_default_vector_index_config, ) from autointent.custom_types import NodeType @@ -64,7 +65,7 @@ def __init__( self.embedder_config = get_default_embedder_config() self.cross_encoder_config = CrossEncoderConfig() self.data_config = DataConfig() - self.transformer_config = HFModelConfig() + self.transformer_config = get_default_hfmodel_config() self.hpo_config = HPOConfig() self.vector_index_config = get_default_vector_index_config() elif not isinstance(nodes[0], InferenceNode): diff --git a/src/autointent/configs/__init__.py b/src/autointent/configs/__init__.py index b0495c372..01f7f3106 100644 --- a/src/autointent/configs/__init__.py +++ b/src/autointent/configs/__init__.py @@ -17,6 +17,7 @@ EmbedderFineTuningConfig, HFModelConfig, TokenizerConfig, + get_default_hfmodel_config, ) from ._vector_index import FaissConfig, OpenSearchConfig, VectorIndexConfig, get_default_vector_index_config @@ -40,6 +41,7 @@ "VectorIndexConfig", "VocabConfig", "get_default_embedder_config", + "get_default_hfmodel_config", "get_default_vector_index_config", "initialize_embedder_config", ] diff --git a/src/autointent/configs/_transformers.py b/src/autointent/configs/_transformers.py index 9bdc9597d..90bed58e8 100644 --- a/src/autointent/configs/_transformers.py +++ b/src/autointent/configs/_transformers.py @@ -53,6 +53,7 @@ class HFModelConfig(BaseModel): fp16: bool = Field(False, description="Whether to use mixed precision training (not all devices support this).") tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig) trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.") + revision: str | None = Field(None, description="Revision from HF repo") @classmethod def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self: @@ -75,6 +76,10 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> return cls(**values) +def get_default_hfmodel_config() -> HFModelConfig: + return HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + + class CrossEncoderConfig(HFModelConfig): model_name: str = Field("cross-encoder/ms-marco-MiniLM-L6-v2", description="Name of the hugging face model.") train_head: bool = Field( diff --git a/src/autointent/context/_context.py b/src/autointent/context/_context.py index 7be00ac22..7410ec4b3 100644 --- a/src/autointent/context/_context.py +++ b/src/autointent/context/_context.py @@ -16,6 +16,7 @@ HPOConfig, LoggingConfig, VectorIndexConfig, + get_default_hfmodel_config, ) from .data_handler import DataHandler @@ -25,9 +26,7 @@ from pathlib import Path from autointent import Dataset - from autointent.configs import ( - DataConfig, - ) + from autointent.configs import DataConfig class Context: @@ -202,4 +201,4 @@ def resolve_transformer(self) -> HFModelConfig: """ if hasattr(self, "transformer_config"): return self.transformer_config - return HFModelConfig() + return get_default_hfmodel_config() From 84465bbb99518fbea71ce6fbda2bf44f730f88c3 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 18:45:18 +0300 Subject: [PATCH 07/21] enable cache in uv ci step --- .github/workflows/reusable-test.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/reusable-test.yaml b/.github/workflows/reusable-test.yaml index 502a0d35a..80b2a4ff4 100644 --- a/.github/workflows/reusable-test.yaml +++ b/.github/workflows/reusable-test.yaml @@ -35,6 +35,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: version: "0.10.0" + enable-cache: "true" - name: Install dependencies for Python ${{ matrix.python-version }} run: | From b3ad53285548758240bbf9b7556b6fe3d75bd2d9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 12 Mar 2026 15:45:49 +0000 Subject: [PATCH 08/21] Update optimizer_config.schema.json --- docs/optimizer_config.schema.json | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index d19eda75b..b44883200 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -50,6 +50,19 @@ "title": "Trust Remote Code", "type": "boolean" }, + "revision": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Revision from HF repo", + "title": "Revision" + }, "train_head": { "default": false, "description": "Whether to train the head of the model. If False, LogReg will be trained.", @@ -262,6 +275,19 @@ "description": "Whether to trust the remote code when loading the model.", "title": "Trust Remote Code", "type": "boolean" + }, + "revision": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Revision from HF repo", + "title": "Revision" } }, "title": "HFModelConfig", @@ -515,6 +541,7 @@ "truncation": true }, "trust_remote_code": false, + "revision": null, "train_head": false } }, @@ -531,7 +558,8 @@ "padding": true, "truncation": true }, - "trust_remote_code": false + "trust_remote_code": false, + "revision": null } }, "hpo_config": { From 55f90bc6fb3719cbcca17c52f8fb1cf5254379e5 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 18:51:54 +0300 Subject: [PATCH 09/21] add tiktoken dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5a089f256..cef1f338b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "datasets (>=3.2.0,<5.0.0)", "xxhash (>=3.5.0,<4.0.0)", "python-dotenv (>=1.0.1,<2.0.0)", - "transformers[torch] (>=4.49.0,<6.0.0)", + "transformers[torch,tiktoken] (>=4.49.0,<6.0.0)", "peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)", "catboost (>=1.2.8,<2.0.0)", "aiometer (>=1.0.0,<2.0.0)", From 65a8cbe249d01c49581d940105e92436748343f0 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 18:52:01 +0300 Subject: [PATCH 10/21] fix typing --- src/autointent/modules/scoring/_bert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/autointent/modules/scoring/_bert.py b/src/autointent/modules/scoring/_bert.py index 4e80e06ad..2150bc01a 100644 --- a/src/autointent/modules/scoring/_bert.py +++ b/src/autointent/modules/scoring/_bert.py @@ -200,8 +200,8 @@ def _train(self, tokenized_dataset: DatasetDict) -> None: callbacks=self._get_trainer_callbacks(), ) if not self.print_progress: - trainer.remove_callback(PrinterCallback) # type: ignore[no-untyped-call] - trainer.remove_callback(ProgressCallback) # type: ignore[no-untyped-call] + trainer.remove_callback(PrinterCallback) + trainer.remove_callback(ProgressCallback) trainer.train() From 9354bf8326c074bf281fe9dfa8fe4aa22a719e11 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 18:58:36 +0300 Subject: [PATCH 11/21] apply hfmodel config properly --- src/autointent/_optimization_config.py | 3 ++- src/autointent/_wrappers/embedder/sentence_transformers.py | 1 + src/autointent/modules/scoring/_bert.py | 5 ++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/autointent/_optimization_config.py b/src/autointent/_optimization_config.py index b78189b29..25b591be1 100644 --- a/src/autointent/_optimization_config.py +++ b/src/autointent/_optimization_config.py @@ -11,6 +11,7 @@ HFModelConfig, HPOConfig, LoggingConfig, + get_default_hfmodel_config, initialize_embedder_config, ) @@ -40,7 +41,7 @@ def validate_embedder_config(cls, v: Any) -> EmbedderConfig: # noqa: ANN401 cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig() - transformer_config: HFModelConfig = HFModelConfig() + transformer_config: HFModelConfig = get_default_hfmodel_config() hpo_config: HPOConfig = HPOConfig() diff --git a/src/autointent/_wrappers/embedder/sentence_transformers.py b/src/autointent/_wrappers/embedder/sentence_transformers.py index 0948ea406..e19833e52 100644 --- a/src/autointent/_wrappers/embedder/sentence_transformers.py +++ b/src/autointent/_wrappers/embedder/sentence_transformers.py @@ -84,6 +84,7 @@ def _load_model(self) -> SentenceTransformer: prompts=self.config.get_prompt_config(), similarity_fn_name=self.config.similarity_fn_name, trust_remote_code=self.config.trust_remote_code, + revision=self.config.revision, ) self._model = res return self._model diff --git a/src/autointent/modules/scoring/_bert.py b/src/autointent/modules/scoring/_bert.py index 2150bc01a..5a1c9a962 100644 --- a/src/autointent/modules/scoring/_bert.py +++ b/src/autointent/modules/scoring/_bert.py @@ -143,6 +143,7 @@ def _initialize_model(self) -> Any: # noqa: ANN401 return AutoModelForSequenceClassification.from_pretrained( self.classification_model_config.model_name, trust_remote_code=self.classification_model_config.trust_remote_code, + revision=self.classification_model_config.revision, num_labels=self._n_classes, label2id=label2id, id2label=id2label, @@ -156,7 +157,9 @@ def fit( ) -> None: self._validate_task(labels) - self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name) + self._tokenizer = AutoTokenizer.from_pretrained( + self.classification_model_config.model_name, revision=self.classification_model_config.revision + ) self._model = self._initialize_model() tokenized_dataset = self._get_tokenized_dataset(utterances, labels) self._train(tokenized_dataset) From 10c69153c62674c6196973bfde38304968995d3b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 12 Mar 2026 15:59:47 +0000 Subject: [PATCH 12/21] Update optimizer_config.schema.json --- docs/optimizer_config.schema.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index b44883200..e45d8cbbf 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -559,7 +559,7 @@ "truncation": true }, "trust_remote_code": false, - "revision": null + "revision": "refs/pr/16" } }, "hpo_config": { From 6a805cf2a0b13a87cdd309fec5a73ca2c9d52139 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 19:08:32 +0300 Subject: [PATCH 13/21] use proper embedder config in gcn and catboost tests --- tests/modules/scoring/test_catboost.py | 7 +++++-- tests/modules/scoring/test_gcn_scorer.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/modules/scoring/test_catboost.py b/tests/modules/scoring/test_catboost.py index 177936980..53a28b995 100644 --- a/tests/modules/scoring/test_catboost.py +++ b/tests/modules/scoring/test_catboost.py @@ -5,9 +5,12 @@ import numpy as np import pytest +from autointent.configs import SentenceTransformerEmbeddingConfig from autointent.context.data_handler import DataHandler from autointent.modules import CatBoostScorer +_embedder_config = SentenceTransformerEmbeddingConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + def test_catboost_scorer_dump_load(dataset): """Test that CatBoostScorer can be saved and loaded while preserving predictions.""" @@ -53,7 +56,7 @@ def test_catboost_prediction_multilabel(dataset): data_handler = DataHandler(dataset.to_multilabel()) scorer = CatBoostScorer( - embedder_config="prajjwal1/bert-tiny", + embedder_config=_embedder_config, iterations=50, learning_rate=0.05, depth=6, @@ -97,7 +100,7 @@ def test_catboost_features_types(dataset, features_type, use_embedding_features) data_handler = DataHandler(dataset) scorer = CatBoostScorer( - embedder_config="prajjwal1/bert-tiny", + embedder_config=_embedder_config, iterations=50, learning_rate=0.05, depth=6, diff --git a/tests/modules/scoring/test_gcn_scorer.py b/tests/modules/scoring/test_gcn_scorer.py index e51d7c547..07a302a1c 100644 --- a/tests/modules/scoring/test_gcn_scorer.py +++ b/tests/modules/scoring/test_gcn_scorer.py @@ -3,8 +3,11 @@ import torch from autointent import Dataset +from autointent.configs import SentenceTransformerEmbeddingConfig from autointent.modules.scoring import GCNScorer +_embedder_config = SentenceTransformerEmbeddingConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + @pytest.fixture def multilabel_dataset(): @@ -44,7 +47,7 @@ def multiclass_dataset(): def test_gcn_scorer_multilabel(multilabel_dataset): torch.manual_seed(42) - scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42) + scorer = GCNScorer(embedder_config=_embedder_config, num_train_epochs=1, batch_size=2, seed=42) train_utterances = multilabel_dataset["train"]["utterance"] train_labels = multilabel_dataset["train"]["label"] descriptions = [intent.name for intent in multilabel_dataset.intents] @@ -59,7 +62,7 @@ def test_gcn_scorer_multilabel(multilabel_dataset): def test_gcn_scorer_multiclass(multiclass_dataset): torch.manual_seed(42) - scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42) + scorer = GCNScorer(embedder_config=_embedder_config, num_train_epochs=1, batch_size=2, seed=42) train_utterances = multiclass_dataset["train"]["utterance"] train_labels = multiclass_dataset["train"]["label"] descriptions = [intent.name for intent in multiclass_dataset.intents] @@ -75,7 +78,7 @@ def test_gcn_scorer_multiclass(multiclass_dataset): def test_gcn_scorer_dump_load(tmp_path, multilabel_dataset): torch.manual_seed(42) - scorer = GCNScorer(embedder_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2, seed=42) + scorer = GCNScorer(embedder_config=_embedder_config, num_train_epochs=1, batch_size=2, seed=42) train_utterances = multilabel_dataset["train"]["utterance"] train_labels = multilabel_dataset["train"]["label"] descriptions = [intent.name for intent in multilabel_dataset.intents] From 6224086710803ed832b64b619f7613e689c5c28e Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 19:09:23 +0300 Subject: [PATCH 14/21] add sentencepiece dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cef1f338b..4dc8db000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "datasets (>=3.2.0,<5.0.0)", "xxhash (>=3.5.0,<4.0.0)", "python-dotenv (>=1.0.1,<2.0.0)", - "transformers[torch,tiktoken] (>=4.49.0,<6.0.0)", + "transformers[torch,tiktoken,sentencepiece] (>=4.49.0,<6.0.0)", "peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)", "catboost (>=1.2.8,<2.0.0)", "aiometer (>=1.0.0,<2.0.0)", From 50294f6025618e8986460c63b942e016c712b2ab Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 19:18:24 +0300 Subject: [PATCH 15/21] try to use different bert tiny --- src/autointent/configs/_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/autointent/configs/_transformers.py b/src/autointent/configs/_transformers.py index 90bed58e8..2deaabbdd 100644 --- a/src/autointent/configs/_transformers.py +++ b/src/autointent/configs/_transformers.py @@ -77,7 +77,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> def get_default_hfmodel_config() -> HFModelConfig: - return HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + return HFModelConfig(model_name="cointegrated/rubert-tiny2") class CrossEncoderConfig(HFModelConfig): From 6a6dcfe6c832c9440155ffd0cdc95fb81705c248 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 12 Mar 2026 16:19:35 +0000 Subject: [PATCH 16/21] Update optimizer_config.schema.json --- docs/optimizer_config.schema.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index e45d8cbbf..55ec7ca31 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -548,7 +548,7 @@ "transformer_config": { "$ref": "#/$defs/HFModelConfig", "default": { - "model_name": "prajjwal1/bert-tiny", + "model_name": "cointegrated/rubert-tiny2", "batch_size": 32, "device": null, "bf16": false, @@ -559,7 +559,7 @@ "truncation": true }, "trust_remote_code": false, - "revision": "refs/pr/16" + "revision": null } }, "hpo_config": { From cdb9c5a221f0a144bff96596cbde1d21ac6e4576 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 19:30:08 +0300 Subject: [PATCH 17/21] use proper hfmodel config in tests --- src/autointent/configs/_transformers.py | 2 +- tests/modules/scoring/test_bert.py | 9 ++++++--- tests/modules/scoring/test_lora.py | 11 ++++++----- tests/modules/scoring/test_ptuning.py | 9 ++++++--- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/autointent/configs/_transformers.py b/src/autointent/configs/_transformers.py index 2deaabbdd..90bed58e8 100644 --- a/src/autointent/configs/_transformers.py +++ b/src/autointent/configs/_transformers.py @@ -77,7 +77,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> def get_default_hfmodel_config() -> HFModelConfig: - return HFModelConfig(model_name="cointegrated/rubert-tiny2") + return HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") class CrossEncoderConfig(HFModelConfig): diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py index 86fbc685b..18de9c8ea 100644 --- a/tests/modules/scoring/test_bert.py +++ b/tests/modules/scoring/test_bert.py @@ -5,16 +5,19 @@ import numpy as np import pytest +from autointent.configs import HFModelConfig from autointent.context.data_handler import DataHandler from autointent.modules import BertScorer +_config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + def test_bert_scorer_dump_load(dataset): """Test that BertScorer can be saved and loaded while preserving predictions.""" data_handler = DataHandler(dataset) # Create and train scorer - scorer_original = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer_original = BertScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8) scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) # Test data @@ -57,7 +60,7 @@ def test_bert_prediction(dataset): """Test that the transformer model can fit and make predictions.""" data_handler = DataHandler(dataset) - scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer = BertScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8) scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) @@ -94,7 +97,7 @@ def test_bert_cache_clearing(dataset): """Test that the transformer model properly handles cache clearing.""" data_handler = DataHandler(dataset) - scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer = BertScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8) scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) diff --git a/tests/modules/scoring/test_lora.py b/tests/modules/scoring/test_lora.py index f9d2725fd..c4245628b 100644 --- a/tests/modules/scoring/test_lora.py +++ b/tests/modules/scoring/test_lora.py @@ -5,18 +5,19 @@ import numpy as np import pytest +from autointent.configs import HFModelConfig from autointent.context.data_handler import DataHandler from autointent.modules import BERTLoRAScorer +_config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + def test_lora_scorer_dump_load(dataset): """Test that BERTLoRAScorer can be saved and loaded while preserving predictions.""" data_handler = DataHandler(dataset) # Create and train scorer - scorer_original = BERTLoRAScorer( - classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8 - ) + scorer_original = BERTLoRAScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8) scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) # Test data @@ -59,7 +60,7 @@ def test_lora_prediction(dataset): """Test that the lora model can fit and make predictions.""" data_handler = DataHandler(dataset) - scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer = BERTLoRAScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8) scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) @@ -96,7 +97,7 @@ def test_lora_cache_clearing(dataset): """Test that the lora model properly handles cache clearing.""" data_handler = DataHandler(dataset) - scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer = BERTLoRAScorer(classification_model_config=_config, num_train_epochs=1, batch_size=8) scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) diff --git a/tests/modules/scoring/test_ptuning.py b/tests/modules/scoring/test_ptuning.py index d74fe39bf..d247b6210 100644 --- a/tests/modules/scoring/test_ptuning.py +++ b/tests/modules/scoring/test_ptuning.py @@ -5,16 +5,19 @@ import numpy as np import pytest +from autointent.configs import HFModelConfig from autointent.context.data_handler import DataHandler from autointent.modules import PTuningScorer +_config = HFModelConfig(model_name="prajjwal1/bert-tiny", revision="refs/pr/16") + def test_ptuning_scorer_dump_load(dataset): """Test that PTuningScorer can be saved and loaded while preserving predictions.""" data_handler = DataHandler(dataset) scorer_original = PTuningScorer( - classification_model_config="prajjwal1/bert-tiny", + classification_model_config=_config, num_train_epochs=1, batch_size=8, num_virtual_tokens=10, @@ -54,7 +57,7 @@ def test_ptuning_prediction(dataset): data_handler = DataHandler(dataset) scorer = PTuningScorer( - classification_model_config="prajjwal1/bert-tiny", + classification_model_config=_config, num_train_epochs=1, batch_size=8, num_virtual_tokens=10, @@ -93,7 +96,7 @@ def test_ptuning_cache_clearing(dataset): data_handler = DataHandler(dataset) scorer = PTuningScorer( - classification_model_config="prajjwal1/bert-tiny", + classification_model_config=_config, num_train_epochs=1, batch_size=8, num_virtual_tokens=20, From a28886c00ed44aa56ed35b8d21cca53b092674d7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 12 Mar 2026 16:31:24 +0000 Subject: [PATCH 18/21] Update optimizer_config.schema.json --- docs/optimizer_config.schema.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index 55ec7ca31..e45d8cbbf 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -548,7 +548,7 @@ "transformer_config": { "$ref": "#/$defs/HFModelConfig", "default": { - "model_name": "cointegrated/rubert-tiny2", + "model_name": "prajjwal1/bert-tiny", "batch_size": 32, "device": null, "bf16": false, @@ -559,7 +559,7 @@ "truncation": true }, "trust_remote_code": false, - "revision": null + "revision": "refs/pr/16" } }, "hpo_config": { From b53945aef6410a8dc18fa9fc94b3dee6543e01dd Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 19:36:29 +0300 Subject: [PATCH 19/21] add revisions to test configs --- tests/assets/configs/multiclass.yaml | 5 ++++- tests/assets/configs/multilabel.yaml | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index fd63f4396..b42781a04 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -39,6 +39,7 @@ features_type: ["embedding"] embedder_config: - model_name: prajjwal1/bert-tiny + revision: refs/pr/16 - module_name: bert classification_model_config: - model_name: avsolatorio/GIST-small-Embedding-v0 @@ -61,7 +62,9 @@ seed: [0] lora_alpha: [16] - module_name: ptuning - classification_model_config: ["prajjwal1/bert-tiny"] + classification_model_config: + - model_name: "prajjwal1/bert-tiny" + revision: refs/pr/16 num_train_epochs: [1] batch_size: [8, 16] num_virtual_tokens: [10, 20] diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index 6dbb0c409..442cc38bf 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -35,6 +35,7 @@ embedder_config: - null - model_name: prajjwal1/bert-tiny + revision: refs/pr/16 - module_name: bert classification_model_config: - model_name: avsolatorio/GIST-small-Embedding-v0 @@ -49,7 +50,9 @@ kernel_sizes: [[3, 4, 5]] num_filters: [100] - module_name: ptuning - classification_model_config: ["prajjwal1/bert-tiny"] + classification_model_config: + - model_name: prajjwal1/bert-tiny + revision: refs/pr/16 num_train_epochs: [1] batch_size: [8] num_virtual_tokens: [10, 20] From 29435465ba1059d1c0ea9e27b1fc0863ab044945 Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 19:49:36 +0300 Subject: [PATCH 20/21] upd unit tests --- tests/callback/test_callback.py | 3 +++ tests/data/test_check_split_readiness.py | 15 +++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index 1a5c646c8..47a3269a2 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -145,6 +145,7 @@ def test_pipeline_callbacks(dataset): "cluster_prompt": None, "sts_prompt": None, "query_prompt": None, + "revision": None, "passage_prompt": None, "similarity_fn_name": None, "use_cache": True, @@ -176,6 +177,7 @@ def test_pipeline_callbacks(dataset): "cluster_prompt": None, "sts_prompt": None, "query_prompt": None, + "revision": None, "passage_prompt": None, "similarity_fn_name": None, "use_cache": True, @@ -203,6 +205,7 @@ def test_pipeline_callbacks(dataset): "cluster_prompt": None, "sts_prompt": None, "query_prompt": None, + "revision": None, "passage_prompt": None, "similarity_fn_name": None, "use_cache": True, diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py index 03d4162b9..962074fc4 100644 --- a/tests/data/test_check_split_readiness.py +++ b/tests/data/test_check_split_readiness.py @@ -130,14 +130,13 @@ def test_check_split_readiness_missing_split(dataset_enough_samples): def test_check_split_readiness_oos_allow_none(dataset_unsplitted): """When dataset has OOS and allow_oos_in_train is None, result is not ready.""" - result = check_split_readiness( - dataset_unsplitted, - split=Split.TRAIN, - test_size=0.5, - allow_oos_in_train=None, - ) - assert result.ready is False - assert "OOS" in result.reason or "allow_oos_in_train" in result.reason + with pytest.raises(ValueError, match="allow_oos_in_train"): + check_split_readiness( + dataset_unsplitted, + split=Split.TRAIN, + test_size=0.5, + allow_oos_in_train=None, + ) def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitted): From 9b2d95c73094d7efe5ea89936bf79d79dc331a7b Mon Sep 17 00:00:00 2001 From: voorhs Date: Thu, 12 Mar 2026 20:03:06 +0300 Subject: [PATCH 21/21] remove previously added dependencies --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4dc8db000..e742242a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,13 +43,12 @@ dependencies = [ "datasets (>=3.2.0,<5.0.0)", "xxhash (>=3.5.0,<4.0.0)", "python-dotenv (>=1.0.1,<2.0.0)", - "transformers[torch,tiktoken,sentencepiece] (>=4.49.0,<6.0.0)", + "transformers[torch] (>=4.49.0,<6.0.0)", "peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)", "catboost (>=1.2.8,<2.0.0)", "aiometer (>=1.0.0,<2.0.0)", "aiofiles (>=24.1.0,<25.0.0)", "threadpoolctl (>=3.0.0,<4.0.0)", - "protobuf>=6,<7" ] [project.optional-dependencies]