Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/reusable-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
uses: astral-sh/setup-uv@v6
with:
version: "0.10.0"
enable-cache: "true"

- name: Install dependencies for Python ${{ matrix.python-version }}
run: |
Expand Down
30 changes: 29 additions & 1 deletion docs/optimizer_config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@
"title": "Trust Remote Code",
"type": "boolean"
},
"revision": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Revision from HF repo",
"title": "Revision"
},
"train_head": {
"default": false,
"description": "Whether to train the head of the model. If False, LogReg will be trained.",
Expand Down Expand Up @@ -262,6 +275,19 @@
"description": "Whether to trust the remote code when loading the model.",
"title": "Trust Remote Code",
"type": "boolean"
},
"revision": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Revision from HF repo",
"title": "Revision"
}
},
"title": "HFModelConfig",
Expand Down Expand Up @@ -515,6 +541,7 @@
"truncation": true
},
"trust_remote_code": false,
"revision": null,
"train_head": false
}
},
Expand All @@ -531,7 +558,8 @@
"padding": true,
"truncation": true
},
"trust_remote_code": false
"trust_remote_code": false,
"revision": "refs/pr/16"
}
},
"hpo_config": {
Expand Down
3 changes: 2 additions & 1 deletion src/autointent/_optimization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
HFModelConfig,
HPOConfig,
LoggingConfig,
get_default_hfmodel_config,
initialize_embedder_config,
)

Expand Down Expand Up @@ -40,7 +41,7 @@ def validate_embedder_config(cls, v: Any) -> EmbedderConfig: # noqa: ANN401

cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

transformer_config: HFModelConfig = HFModelConfig()
transformer_config: HFModelConfig = get_default_hfmodel_config()

hpo_config: HPOConfig = HPOConfig()

Expand Down
3 changes: 2 additions & 1 deletion src/autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LoggingConfig,
VectorIndexConfig,
get_default_embedder_config,
get_default_hfmodel_config,
get_default_vector_index_config,
)
from autointent.custom_types import NodeType
Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(
self.embedder_config = get_default_embedder_config()
self.cross_encoder_config = CrossEncoderConfig()
self.data_config = DataConfig()
self.transformer_config = HFModelConfig()
self.transformer_config = get_default_hfmodel_config()
self.hpo_config = HPOConfig()
self.vector_index_config = get_default_vector_index_config()
elif not isinstance(nodes[0], InferenceNode):
Expand Down
1 change: 1 addition & 0 deletions src/autointent/_wrappers/embedder/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _load_model(self) -> SentenceTransformer:
prompts=self.config.get_prompt_config(),
similarity_fn_name=self.config.similarity_fn_name,
trust_remote_code=self.config.trust_remote_code,
revision=self.config.revision,
)
self._model = res
return self._model
Expand Down
2 changes: 2 additions & 0 deletions src/autointent/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EmbedderFineTuningConfig,
HFModelConfig,
TokenizerConfig,
get_default_hfmodel_config,
)
from ._vector_index import FaissConfig, OpenSearchConfig, VectorIndexConfig, get_default_vector_index_config

Expand All @@ -40,6 +41,7 @@
"VectorIndexConfig",
"VocabConfig",
"get_default_embedder_config",
"get_default_hfmodel_config",
"get_default_vector_index_config",
"initialize_embedder_config",
]
5 changes: 5 additions & 0 deletions src/autointent/configs/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class HFModelConfig(BaseModel):
fp16: bool = Field(False, description="Whether to use mixed precision training (not all devices support this).")
tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.")
revision: str | None = Field(None, description="Revision from HF repo")

@classmethod
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
Expand All @@ -75,6 +76,10 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
return cls(**values)


def get_default_hfmodel_config() -> HFModelConfig:
    """Build the library-wide default :class:`HFModelConfig`.

    Returns a config pointing at the lightweight ``prajjwal1/bert-tiny``
    model, pinned to revision ``refs/pr/16``.

    NOTE(review): pinning to a ``refs/pr/*`` revision presumably selects a
    patched repo state (e.g. added safetensors) — confirm the PR ref is the
    intended long-term pin, since PR refs can be force-updated upstream.
    """
    default_settings = {
        "model_name": "prajjwal1/bert-tiny",
        "revision": "refs/pr/16",
    }
    return HFModelConfig(**default_settings)


class CrossEncoderConfig(HFModelConfig):
model_name: str = Field("cross-encoder/ms-marco-MiniLM-L6-v2", description="Name of the hugging face model.")
train_head: bool = Field(
Expand Down
7 changes: 3 additions & 4 deletions src/autointent/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HPOConfig,
LoggingConfig,
VectorIndexConfig,
get_default_hfmodel_config,
)

from .data_handler import DataHandler
Expand All @@ -25,9 +26,7 @@
from pathlib import Path

from autointent import Dataset
from autointent.configs import (
DataConfig,
)
from autointent.configs import DataConfig


class Context:
Expand Down Expand Up @@ -202,4 +201,4 @@ def resolve_transformer(self) -> HFModelConfig:
"""
if hasattr(self, "transformer_config"):
return self.transformer_config
return HFModelConfig()
return get_default_hfmodel_config()
15 changes: 13 additions & 2 deletions src/autointent/context/data_handler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from ._data_handler import DataHandler
from ._stratification import StratifiedSplitter, split_dataset
from ._stratification import (
SplitReadinessResult,
StratifiedSplitter,
check_split_readiness,
split_dataset,
)

__all__ = ["DataHandler", "StratifiedSplitter", "split_dataset"]
# Public API of the data-handler subpackage: governs ``from ... import *``
# and signals which re-exported names are supported for external use.
__all__ = [
    "DataHandler",
    "SplitReadinessResult",
    "StratifiedSplitter",
    "check_split_readiness",
    "split_dataset",
]
Loading
Loading