Skip to content

Commit b42d9cf

Browse files
committed
feat: track token provenance
1 parent 0402b98 commit b42d9cf

4 files changed

Lines changed: 26 additions & 14 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from model2vec.distill.inference import create_embeddings
1515
from model2vec.distill.tokenizer import replace_vocabulary
16-
from model2vec.distill.utils import select_optimal_device
16+
from model2vec.distill.utils import Token, select_optimal_device
1717
from model2vec.model import StaticModel
1818
from model2vec.quantization import DType, quantize_embeddings
1919

model2vec/distill/inference.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from transformers import PreTrainedModel, PreTrainedTokenizerFast
1515
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
1616

17-
from model2vec.distill.utils import filter_vocabulary_by_regex
17+
from model2vec.distill.utils import Token, filter_vocabulary_by_regex
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -35,7 +35,7 @@ def create_embeddings(
3535
tokens: list[str],
3636
device: str,
3737
token_remove_regex: re.Pattern | None,
38-
) -> tuple[list[str], np.ndarray]:
38+
) -> tuple[list[Token], np.ndarray]:
3939
"""
4040
Create output embeddings for a bunch of tokens using a pretrained model.
4141
@@ -55,7 +55,7 @@ def create_embeddings(
5555
out_weights: np.ndarray
5656
intermediate_weights: list[np.ndarray] = []
5757

58-
out_tokens = []
58+
out_tokens: list[Token] = []
5959
tokenized: list[torch.Tensor] = []
6060
pad_token = tokenizer.special_tokens_map.get("pad_token")
6161
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
@@ -89,7 +89,8 @@ def create_embeddings(
8989
eos = torch.full([len(ids)], fill_value=eos_token_id)
9090

9191
tokenized.extend(torch.stack([bos, ids, eos], dim=1))
92-
out_tokens.extend(tokenizer.convert_ids_to_tokens(ids))
92+
subword_tokens = [Token(x, True) for x in tokenizer.convert_ids_to_tokens(ids.tolist())]
93+
out_tokens.extend(subword_tokens)
9394

9495
tokenized.extend([tokenizer.encode_plus(token, return_tensors="pt")["input_ids"][0] for token in tokens])
9596

@@ -119,7 +120,7 @@ def create_embeddings(
119120

120121
# Sort the output back to the original order
121122
intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
122-
out_tokens.extend(tokens)
123+
out_tokens.extend([Token(x, False) for x in tokens])
123124
out_weights = np.stack(intermediate_weights)
124125

125126
return out_tokens, out_weights

model2vec/distill/tokenizer.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from tokenizers import Tokenizer
88

9+
from model2vec.distill.utils import Token
10+
911
logger = logging.getLogger(__name__)
1012

1113

@@ -17,7 +19,7 @@
1719
}
1820

1921

20-
def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[str]) -> list[str]:
22+
def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[str]:
2123
"""
2224
Apply pre-tokenization to vocabulary tokens if a pre-tokenizer is present.
2325
@@ -33,14 +35,14 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[str]) -> list[st
3335

3436
if tokenizer.pre_tokenizer is not None:
3537
for token in tokens:
36-
if token in current_tokenizer_vocab:
37-
pre_tokenized_tokens.append(token)
38+
if token.is_subword:
39+
pre_tokenized_tokens.append(token.form)
3840
else:
3941
# We are 100% sure that all pretokenized tokens will have length 1.
40-
pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(f" {token}"))
42+
pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(f" {token.form}"))
4143
pre_tokenized_tokens.append(pretokenized_tokens[-1])
4244
else:
43-
pre_tokenized_tokens = tokens
45+
pre_tokenized_tokens = [token.form for token in tokens]
4446

4547
return pre_tokenized_tokens
4648

@@ -106,7 +108,7 @@ def _make_new_merges_from_vocab(
106108

107109

108110
def replace_vocabulary(
109-
tokenizer: Tokenizer, new_vocabulary: list[str], unk_token: str | None, pad_token: str | None
111+
tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
110112
) -> Tokenizer:
111113
"""Replace the vocabulary of a tokenizer with a new one."""
112114
tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
@@ -139,8 +141,8 @@ def replace_vocabulary(
139141
vocab = tokenizer_json["model"]["vocab"]
140142
unk_token = vocab[unk_id][0] if unk_id is not None else None
141143
current_probas = dict(tokenizer_json["model"]["vocab"])
142-
lowest_proba = min(current_probas.values())
143-
new_probas = {word: current_probas.get(word, lowest_proba) for word in pre_tokenized_tokens}
144+
avg_proba = sum(current_probas.values()) / len(current_probas)
145+
new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
144146
tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
145147

146148
tokens, _ = zip(*tokenizer_json["model"]["vocab"])

model2vec/distill/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
from __future__ import annotations
22

33
import re
4+
from dataclasses import dataclass
45
from logging import getLogger
56

67
import torch
78

89
logger = getLogger(__name__)
910

1011

12+
@dataclass
13+
class Token:
14+
"""A class to represent a token."""
15+
16+
form: str
17+
is_subword: bool
18+
19+
1120
def select_optimal_device(device: str | None) -> str:
1221
"""
1322
Guess what your optimal device should be based on backend availability.

0 commit comments

Comments
 (0)