diff --git a/model2vec/train/base.py b/model2vec/train/base.py index 368b4d1..19563b0 100644 --- a/model2vec/train/base.py +++ b/model2vec/train/base.py @@ -15,8 +15,8 @@ from torch.nn.utils.rnn import pad_sequence from tqdm import trange -from model2vec import StaticModel from model2vec.inference import StaticModelPipeline +from model2vec.model import PathLike, StaticModel from model2vec.train.dataset import TextDataset from model2vec.train.utils import ( get_probable_pad_token_id, @@ -140,7 +140,7 @@ def _initialize(self) -> None: @classmethod def from_pretrained( cls: type[ModelType], - path: str = "minishlab/potion-base-32m", + path: PathLike = "minishlab/potion-base-32m", *, token: str | None = None, **kwargs: Any,