Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ wheel>=0.43
pyyaml
py-cpuinfo
torch>=2.6.0
transformers>=5.0.0
transformers>=4.57.1

datasets>=2.15.0
numba>=0.62.0
Expand Down
10 changes: 7 additions & 3 deletions src/instructlab/training/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
# First Party
from instructlab.training.config import DataProcessArgs
from instructlab.training.logger import setup_root_logger
from instructlab.training.tokenizer_utils import get_sp_token, setup_tokenizer
from instructlab.training.tokenizer_utils import (
SPECIAL_TOKENS_KEY,
get_sp_token,
setup_tokenizer,
)
from instructlab.training.type_definitions import Message, ProcessedMessagesData
from instructlab.training.utils import log_rank_0, retrieve_chat_template

Expand Down Expand Up @@ -393,7 +397,7 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs):

# Adding after tokenizer setup as these are temp tokens, not to be saved
tokenizer.add_special_tokens(
{"extra_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<|MASK|>"]}
{SPECIAL_TOKENS_KEY: ["<|pretrain|>", "<|/pretrain|>", "<|MASK|>"]}
)

try:
Expand Down Expand Up @@ -1300,7 +1304,7 @@ def configure_tokenizer(model_path: str) -> PreTrainedTokenizer:
# Add special tokens for masking
tokenizer.add_special_tokens(
{
"extra_special_tokens": [
SPECIAL_TOKENS_KEY: [
UNMASK_BEGIN_TOKEN,
UNMASK_END_TOKEN,
UNMASK_REASONING_BEGIN_TOKEN,
Expand Down
25 changes: 20 additions & 5 deletions src/instructlab/training/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@

# Third Party
from transformers import AutoTokenizer, PreTrainedTokenizer
import transformers

# First Party
from instructlab.training.utils import log_rank_0, retrieve_chat_template

# Transformers v5 renamed the tokenizer kwarg 'additional_special_tokens'
# to 'extra_special_tokens'; pick the right key for the installed version.
_major = int(transformers.__version__.split(".", maxsplit=1)[0])
_TRANSFORMERS_V5 = _major >= 5
if _TRANSFORMERS_V5:
    SPECIAL_TOKENS_KEY = "extra_special_tokens"
else:
    SPECIAL_TOKENS_KEY = "additional_special_tokens"


def get_extra_special_tokens(tokenizer: PreTrainedTokenizer) -> list[str]:
    """Return the tokenizer's extra/additional special tokens.

    Compatible with both transformers v4 ('additional_special_tokens')
    and v5 ('extra_special_tokens'), which expose the same list under
    different attribute names.
    """
    attr = (
        "extra_special_tokens" if _TRANSFORMERS_V5 else "additional_special_tokens"
    )
    return getattr(tokenizer, attr)


def setup_tokenizer_with_existing_chat_template(
tokenizer: PreTrainedTokenizer,
Expand All @@ -19,16 +33,17 @@ def setup_tokenizer_with_existing_chat_template(
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

# ensure the pad token is in the extra special tokens without duplicating anything else
current_special = get_extra_special_tokens(tokenizer)
new_tokens = []
if tokenizer.pad_token not in tokenizer.extra_special_tokens:
if tokenizer.pad_token not in current_special:
new_tokens.append(tokenizer.pad_token)
if tokenizer.eos_token not in tokenizer.extra_special_tokens:
if tokenizer.eos_token not in current_special:
new_tokens.append(tokenizer.eos_token)

# ensure the tokens are being sorted to prevent any issues
new_tokens = sorted(new_tokens)
extra_special_tokens = tokenizer.extra_special_tokens + new_tokens
tokenizer.add_special_tokens({"extra_special_tokens": extra_special_tokens})
extra_special_tokens = current_special + new_tokens
tokenizer.add_special_tokens({SPECIAL_TOKENS_KEY: extra_special_tokens})

# ensure the necessary tokens exist
assert len(get_sp_token(tokenizer, tokenizer.pad_token)) == 1, (
Expand All @@ -55,7 +70,7 @@ def setup_tokenizer_from_new_chat_template(
}
)
tokenizer.add_special_tokens(
{"extra_special_tokens": SPECIAL_TOKENS.get_tokens_to_add()}
{SPECIAL_TOKENS_KEY: SPECIAL_TOKENS.get_tokens_to_add()}
)
if getattr(tokenizer, "add_bos_token", False) or getattr(
tokenizer, "add_eos_token", False
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
unmask_sample,
wrap_masked_messages,
)
from instructlab.training.tokenizer_utils import SPECIAL_TOKENS_KEY
from instructlab.training.type_definitions import Message, ProcessedMessagesData


Expand Down Expand Up @@ -853,7 +854,7 @@ def test_with_qwen_tokenizer(self):
# Add the unmask tokens to the tokenizer
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
SPECIAL_TOKENS_KEY: [
UNMASK_BEGIN_TOKEN,
UNMASK_END_TOKEN,
UNMASK_REASONING_BEGIN_TOKEN,
Expand Down Expand Up @@ -909,7 +910,7 @@ def test_with_phi_tokenizer(self):
# Add the unmask tokens to the tokenizer
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
SPECIAL_TOKENS_KEY: [
UNMASK_BEGIN_TOKEN,
UNMASK_END_TOKEN,
UNMASK_REASONING_BEGIN_TOKEN,
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_unmask_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
unmask_sample,
wrap_masked_messages,
)
from instructlab.training.tokenizer_utils import SPECIAL_TOKENS_KEY
from instructlab.training.type_definitions import Message


Expand Down Expand Up @@ -342,7 +343,7 @@ def test_tokenizer(self):
# Add the special tokens
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
SPECIAL_TOKENS_KEY: [
UNMASK_BEGIN_TOKEN,
UNMASK_END_TOKEN,
UNMASK_REASONING_BEGIN_TOKEN,
Expand Down Expand Up @@ -492,7 +493,7 @@ def real_tokenizer(self, request):
# Add the special unmask tokens
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
SPECIAL_TOKENS_KEY: [
UNMASK_BEGIN_TOKEN,
UNMASK_END_TOKEN,
UNMASK_REASONING_BEGIN_TOKEN,
Expand Down
Loading