55from logging import getLogger
66from pathlib import Path
77from tempfile import TemporaryDirectory
8- from typing import Any , Iterator , Union
8+ from typing import Any , Iterator , Sequence , Union , overload
99
1010import numpy as np
1111from joblib import delayed
@@ -117,7 +117,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
117117 subfolder = subfolder ,
118118 )
119119
120- def tokenize (self , sentences : list [str ], max_length : int | None = None ) -> list [list [int ]]:
120+ def tokenize (self , sentences : Sequence [str ], max_length : int | None = None ) -> list [list [int ]]:
121121 """
122122 Tokenize a list of sentences.
123123
@@ -245,9 +245,31 @@ def from_sentence_transformers(
245245 language = metadata .get ("language" ),
246246 )
247247
248+ @overload
248249 def encode_as_sequence (
249250 self ,
250- sentences : list [str ] | str ,
251+ sentences : str ,
252+ max_length : int | None = None ,
253+ batch_size : int = 1024 ,
254+ show_progress_bar : bool = False ,
255+ use_multiprocessing : bool = True ,
256+ multiprocessing_threshold : int = 10_000 ,
257+ ) -> np .ndarray : ...
258+
259+ @overload
260+ def encode_as_sequence (
261+ self ,
262+ sentences : list [str ],
263+ max_length : int | None = None ,
264+ batch_size : int = 1024 ,
265+ show_progress_bar : bool = False ,
266+ use_multiprocessing : bool = True ,
267+ multiprocessing_threshold : int = 10_000 ,
268+ ) -> list [np .ndarray ]: ...
269+
270+ def encode_as_sequence (
271+ self ,
272+ sentences : str | list [str ],
251273 max_length : int | None = None ,
252274 batch_size : int = 1024 ,
253275 show_progress_bar : bool = False ,
@@ -263,6 +285,9 @@ def encode_as_sequence(
263285 This is about twice as slow.
264286 Sentences that do not contain any tokens will be turned into an empty array.
265287
288+ NOTE: the input type is currently underspecified. The intended input type is `Sequence[str] | str`, but this
289+ cannot currently be expressed with overloads in Python typing, because `str` is itself a `Sequence[str]`.
290+
266291 :param sentences: The list of sentences to encode.
267292 :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
268293 If this is None, no truncation is done.
@@ -320,7 +345,7 @@ def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None
320345
321346 def encode (
322347 self ,
323- sentences : list [str ] | str ,
348+ sentences : Sequence [str ],
324349 show_progress_bar : bool = False ,
325350 max_length : int | None = 512 ,
326351 batch_size : int = 1024 ,
@@ -334,6 +359,9 @@ def encode(
334359 This function encodes a list of sentences by averaging the word embeddings of the tokens in the sentence.
335360 For ease of use, we don't batch sentences together.
336361
362+ NOTE: the return type is currently underspecified. In the case of a single string, this returns a 1D array,
363+ but in the case of a list of strings, this returns a 2D array. numpy's type annotations cannot currently express this.
364+
337365 :param sentences: The list of sentences to encode. You can also pass a single sentence.
338366 :param show_progress_bar: Whether to show the progress bar.
339367 :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
@@ -378,7 +406,7 @@ def encode(
378406 return out_array [0 ]
379407 return out_array
380408
381- def _encode_batch (self , sentences : list [str ], max_length : int | None ) -> np .ndarray :
409+ def _encode_batch (self , sentences : Sequence [str ], max_length : int | None ) -> np .ndarray :
382410 """Encode a batch of sentences."""
383411 ids = self .tokenize (sentences = sentences , max_length = max_length )
384412 out : list [np .ndarray ] = []
@@ -396,7 +424,7 @@ def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndar
396424 return out_array
397425
398426 @staticmethod
399- def _batch (sentences : list [str ], batch_size : int ) -> Iterator [list [str ]]:
427+ def _batch (sentences : Sequence [str ], batch_size : int ) -> Iterator [Sequence [str ]]:
400428 """Batch the sentences into consecutive batches of at most `batch_size` sentences."""
401429 return (sentences [i : i + batch_size ] for i in range (0 , len (sentences ), batch_size ))
402430
0 commit comments