diff --git a/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py b/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py index 0086f162ea..81122fcc57 100644 --- a/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py +++ b/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest -from transformers import AutoTokenizer from pyrit.message_normalizer import TokenizerTemplateNormalizer from pyrit.models import Message, MessagePiece @@ -116,8 +115,18 @@ class TestNormalizeStringAsync: @pytest.fixture def chatml_tokenizer_normalizer(self): - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - return TokenizerTemplateNormalizer(tokenizer=tokenizer) + def _apply_chatml_template(messages, tokenize=False, add_generation_prompt=False): + """Simulate ChatML template formatting.""" + result = "" + for msg in messages: + result += f"<|{msg['role']}|>\n{msg['content']}\n" + if add_generation_prompt: + result += "<|assistant|>\n" + return result + + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.side_effect = _apply_chatml_template + return TokenizerTemplateNormalizer(tokenizer=mock_tokenizer) @pytest.mark.asyncio async def test_normalize_chatml(self, chatml_tokenizer_normalizer: TokenizerTemplateNormalizer):