# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import asyncio
import base64
import concurrent.futures
import io
import logging
import wave
from typing import Any, Callable, Coroutine, Dict, List, Optional

logger = logging.getLogger(__name__)

# Upper bound, in seconds, for one probe request. It is enforced inside the
# coroutine so that it holds on both the "no loop" and "worker thread" paths.
_PROBE_TIMEOUT_SECONDS = 10

# Base64 of a 1x1 PNG, used as the smallest possible image payload.
_MINIMAL_PNG_BASE64 = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
)


def verify_target_capabilities(
    target: Any,
    test_modalities: Optional[List[str]] = None,
) -> Dict[str, bool]:
    """
    Probe a target at runtime to see which input modalities its API accepts.

    This is an optional verification helper: it sends one minimal chat
    completion request per modality and records whether the request succeeded.
    It never raises for a failed probe; a failure is reported as ``False``.

    Args:
        target: Object exposing ``_async_client`` (an object with an async
            ``chat.completions.create``) and ``model_name``.
        test_modalities: Modalities to probe. Defaults to ``["image", "audio"]``.
            Names other than ``"image"`` and ``"audio"`` are reported as ``False``.

    Returns:
        Dict mapping each requested modality name to ``True`` when the probe
        request succeeded and ``False`` otherwise.
    """
    if test_modalities is None:
        test_modalities = ["image", "audio"]

    if not hasattr(target, "_async_client") or not hasattr(target, "model_name"):
        logger.warning(f"Target {type(target).__name__} doesn't support capability testing")
        return {modality: False for modality in test_modalities}

    # _test_single_modality converts every failure into False, so no additional
    # exception handling is needed around the individual probes.
    capabilities = {modality: _test_single_modality(target, modality) for modality in test_modalities}

    logger.info(f"Verified capabilities for {target.model_name}: {capabilities}")
    return capabilities


def _silent_wav_base64(duration_seconds: float = 0.1, sample_rate: int = 16000) -> str:
    """Return a base64-encoded mono 16-bit PCM WAV clip containing only silence."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * int(duration_seconds * sample_rate))
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def _build_probe_messages(modality: str) -> Optional[List[Dict[str, Any]]]:
    """Build the chat messages used to probe ``modality``, or None if it has no probe."""
    if modality == "image":
        attachment: Dict[str, Any] = {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{_MINIMAL_PNG_BASE64}"},
        }
        prompt = "Test image?"
    elif modality == "audio":
        # The probe must actually carry audio: a text-only message would succeed
        # on every chat model and report audio support that does not exist.
        # NOTE(review): assumes a 0.1 s silent WAV is accepted by audio-capable
        # deployments -- confirm against the target API.
        attachment = {
            "type": "input_audio",
            "input_audio": {"data": _silent_wav_base64(), "format": "wav"},
        }
        prompt = "Test audio?"
    else:
        return None

    return [{"role": "user", "content": [{"type": "text", "text": prompt}, attachment]}]


async def _send_probe(target: Any, messages: List[Dict[str, Any]]) -> bool:
    """Send one probe request; return True on success and False on any error or timeout."""
    try:
        await asyncio.wait_for(
            target._async_client.chat.completions.create(
                model=target.model_name,
                messages=messages,
                max_tokens=1,
            ),
            timeout=_PROBE_TIMEOUT_SECONDS,
        )
    except Exception as exc:
        # Deliberately broad: any failure of the request means the modality
        # cannot be confirmed, which callers treat as "not supported".
        logger.debug(f"Modality probe request failed: {exc}")
        return False
    return True


def _run_blocking(coroutine_factory: Callable[[], Coroutine[Any, Any, bool]]) -> bool:
    """Run the coroutine to completion from synchronous code, with or without a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so asyncio.run can be used directly.
        return asyncio.run(coroutine_factory())

    # asyncio.run cannot be nested inside a running loop, so the probe runs on
    # its own loop in a worker thread. The wait is bounded by the timeout that
    # _send_probe applies to the request itself.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(lambda: asyncio.run(coroutine_factory())).result()


def _test_single_modality(target: Any, modality: str) -> bool:
    """
    Test a single modality with a minimal API request.

    Args:
        target: Object exposing ``_async_client`` and ``model_name``.
        modality: ``"image"`` or ``"audio"``; any other value yields ``False``.

    Returns:
        True when the probe request completed without error, False otherwise
        (including unknown modalities, timeouts and unexpected failures).
    """
    try:
        messages = _build_probe_messages(modality)
        if messages is None:
            return False
        return _run_blocking(lambda: _send_probe(target, messages))
    except Exception as exc:
        logger.debug(f"Modality probe for '{modality}' could not be run: {exc}")
        return False
# Members added to ``PromptTarget`` (pyrit/prompt_target/common/prompt_target.py).

#: Set of supported input modality combinations. Each frozenset represents a valid
#: combination of modalities that can be sent together in a single request.
#: Example: {frozenset({"text"}), frozenset({"text", "image_path"})} means supports text-only OR text+image
SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset({"text"})}

#: Set of supported output modality combinations. Each frozenset represents a valid
#: combination of modalities that can be returned together in a single response.
#: Example: {frozenset({"text"})} means produces text-only outputs
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset({"text"})}


def input_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Check if a specific combination of input modalities is supported by this target.

    Args:
        modalities: The set of modalities to check together (e.g., {"text", "image_path"}).

    Returns:
        bool: True if the exact combination is supported, False otherwise.
    """
    # Attribute access works whether SUPPORTED_INPUT_MODALITIES is a plain
    # class attribute or a property defined by a subclass.
    return frozenset(modalities) in self.SUPPORTED_INPUT_MODALITIES


def output_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Check if a specific combination of output modalities is supported by this target.

    Args:
        modalities: The set of modalities to check together (e.g., {"text", "image_url"}).

    Returns:
        bool: True if the exact combination is supported, False otherwise.
    """
    return frozenset(modalities) in self.SUPPORTED_OUTPUT_MODALITIES


@property
def supported_input_modalities(self) -> set[PromptDataType]:
    """
    Get all individual input modalities supported by this target across all combinations.

    Returns:
        set[PromptDataType]: Set of all individual modalities that appear in any supported
        combination; empty when no combination is declared.
    """
    # Start from an empty ``set`` instance: the unbound ``set.union`` rejects a
    # ``frozenset`` as its receiver and would raise TypeError here.
    return set().union(*self.SUPPORTED_INPUT_MODALITIES)


@property
def supported_output_modalities(self) -> set[PromptDataType]:
    """
    Get all individual output modalities supported by this target across all combinations.

    Returns:
        set[PromptDataType]: Set of all individual modalities that appear in any supported
        combination; empty when no combination is declared.
    """
    # See supported_input_modalities for why the union starts from ``set()``.
    return set().union(*self.SUPPORTED_OUTPUT_MODALITIES)
+from __future__ import annotations + import asyncio import logging import os @@ -19,6 +21,7 @@ from pyrit.exceptions import EmptyResponseException, pyrit_target_retry from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request +from pyrit.models.literals import PromptDataType from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -34,6 +37,16 @@ class HuggingFaceChatTarget(PromptChatTarget): Inherits from PromptTarget to comply with the current design standards. """ + #: HuggingFace Chat targets support only text input and output + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset({"text"}) + } + + #: HuggingFace Chat targets produce only text outputs + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset({"text"}) + } + # Class-level cache for model and tokenizer _cached_model = None _cached_tokenizer = None diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 87ffa26f4..0a53f607b 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
# Members added to ``OpenAIChatTarget`` (pyrit/prompt_target/openai/openai_chat_target.py).


@property
def SUPPORTED_INPUT_MODALITIES(self) -> set[frozenset[PromptDataType]]:
    """
    Get supported input modalities inferred from the model name.

    The decision is a name-based heuristic only; no request is sent. Use
    ``verify_capabilities()`` for a runtime check.

    Returns:
        ``{frozenset({"text"})}`` for text-only models, or that set plus
        ``frozenset({"text", "image_path"})`` for models assumed to accept images.
        A missing model name is treated as an unknown, text-only model.
    """
    import re  # local import so this member does not depend on the file's import block

    text_only = {frozenset({"text"})}
    text_and_image = {frozenset({"text"}), frozenset({"text", "image_path"})}

    # model_name may be unset; treat that like an unknown model instead of failing.
    model_lower = (self.model_name or "").lower()

    # Text-only families. The legacy names are matched as whole words so that a
    # name merely containing the letters (e.g. "canada-gpt-4o") is not misclassified.
    if "gpt-3.5" in model_lower or re.search(r"(?<![a-z])(davinci|curie|babbage|ada)(?![a-z])", model_lower):
        return text_only

    # Multimodal indicators (vision/image support)
    multimodal_indicators = (
        "vision",  # gpt-4-vision-preview, etc.
        "gpt-4o",  # gpt-4o, gpt-4o-mini, gpt-4o-2024-*, etc.
        "gpt-4-turbo",  # Often has vision
        "gpt-5",  # Future models likely multimodal
        "gpt-4.5",  # Hypothetical intermediate
        "multimodal",  # Explicit in name
        "omni",  # Omni-modal models
    )
    if any(indicator in model_lower for indicator in multimodal_indicators):
        return text_and_image

    # Other GPT-4 names are assumed multimodal unless they are a known text-only snapshot.
    # NOTE(review): this also classifies snapshots such as gpt-4-0613 as image-capable --
    # confirm against the model documentation.
    known_text_only_gpt4 = ("gpt-4-0314", "gpt-4-32k")
    if "gpt-4" in model_lower and not any(old in model_lower for old in known_text_only_gpt4):
        return text_and_image

    # Conservative default for completely unknown models
    return text_only


@property
def SUPPORTED_OUTPUT_MODALITIES(self) -> set[frozenset[PromptDataType]]:
    """Return the output modality combinations; this target declares text-only output."""
    return {frozenset({"text"})}


def verify_capabilities(self, use_as_fallback: bool = False) -> dict[str, bool]:
    """
    Optional verification of actual capabilities via runtime testing.

    Args:
        use_as_fallback: If True, the result is stored on the instance as
            ``_verified_modalities`` unless a result is already stored there.
            ``SUPPORTED_INPUT_MODALITIES`` itself is not changed.

    Returns:
        Dict mapping modality names to actual support status.
    """
    from pyrit.common.modality_discovery import verify_target_capabilities

    capabilities = verify_target_capabilities(self)

    # Keep the first verified result; later calls do not overwrite it.
    if use_as_fallback and not hasattr(self, "_verified_modalities"):
        self._verified_modalities = capabilities

    return capabilities


def get_verified_input_modalities(self) -> set[frozenset[PromptDataType]]:
    """
    Get input modalities based on runtime verification instead of the model name.

    Uses the result stored by a previous verification when present; otherwise
    runs ``verify_capabilities(use_as_fallback=True)``, which stores it.

    Returns:
        A set that always contains ``frozenset({"text"})`` plus one
        ``{"text", "<kind>_path"}`` combination for each of image, audio and
        video that the verification result marks as supported.
    """
    if hasattr(self, "_verified_modalities"):
        capabilities = self._verified_modalities
    else:
        capabilities = self.verify_capabilities(use_as_fallback=True)

    modalities = {frozenset({"text"})}  # Always support text
    for capability, data_type in (
        ("image", "image_path"),
        ("audio", "audio_path"),
        ("video", "video_path"),
    ):
        if capabilities.get(capability, False):
            modalities.add(frozenset({"text", data_type}))

    return modalities
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from unittest.mock import Mock

import pytest

from pyrit.common.modality_discovery import verify_target_capabilities

_TEXT_ONLY = frozenset({"text"})
_TEXT_AND_IMAGE = frozenset({"text", "image_path"})


class MockOpenAIChatTarget:
    """
    Stand-in for the OpenAI chat target that needs no environment variables.

    This class carries its own copy of the model-name heuristic, so the tests
    below exercise this copy rather than the production target class.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name

    @property
    def SUPPORTED_INPUT_MODALITIES(self) -> set[frozenset[str]]:
        """Static pattern matching for known multimodal models with future-proof heuristics."""
        model_lower = self.model_name.lower()

        # Text-only patterns (explicit)
        if any(pattern in model_lower for pattern in ["gpt-3.5", "davinci", "curie", "babbage", "ada"]):
            return {_TEXT_ONLY}

        # Multimodal indicators (vision/image support)
        multimodal_indicators = [
            "vision",  # gpt-4-vision-preview, etc.
            "gpt-4o",  # gpt-4o, gpt-4o-mini, gpt-4o-2024-*, etc.
            "gpt-4-turbo",  # Often has vision
            "gpt-5",  # Future models likely multimodal
            "gpt-4.5",  # Hypothetical intermediate
            "multimodal",  # Explicit in name
            "omni",  # Omni-modal models
        ]

        if any(indicator in model_lower for indicator in multimodal_indicators):
            return {_TEXT_ONLY, _TEXT_AND_IMAGE}

        # For unknown GPT-4+ models, assume multimodal (safer for newer models)
        if "gpt-4" in model_lower and not any(
            old_pattern in model_lower for old_pattern in ["gpt-4-0314", "gpt-4-32k"]
        ):
            return {_TEXT_ONLY, _TEXT_AND_IMAGE}

        # Conservative default for completely unknown models
        return {_TEXT_ONLY}

    @property
    def SUPPORTED_OUTPUT_MODALITIES(self) -> set[frozenset[str]]:
        """OpenAI models currently only produce text outputs."""
        return {_TEXT_ONLY}

    def input_modality_supported(self, modality: frozenset[str]) -> bool:
        """Check if input modality combination is supported."""
        return modality in self.SUPPORTED_INPUT_MODALITIES

    def output_modality_supported(self, modality: frozenset[str]) -> bool:
        """Check if output modality combination is supported."""
        return modality in self.SUPPORTED_OUTPUT_MODALITIES


class TestSimpleModalitySupport:
    """Test the simplified modality support system with static declarations."""

    def test_gpt4o_static_modalities(self):
        """Test that GPT-4o correctly declares multimodal support."""
        target = MockOpenAIChatTarget("gpt-4o")

        input_modalities = target.SUPPORTED_INPUT_MODALITIES
        output_modalities = target.SUPPORTED_OUTPUT_MODALITIES

        # Should support text-only and text+image
        assert _TEXT_ONLY in input_modalities
        assert _TEXT_AND_IMAGE in input_modalities
        assert len(input_modalities) == 2

        # Output is text-only
        assert output_modalities == {_TEXT_ONLY}

    def test_gpt35_static_modalities(self):
        """Test that GPT-3.5 correctly declares text-only support."""
        target = MockOpenAIChatTarget("gpt-3.5-turbo")

        # Should support only text
        assert target.SUPPORTED_INPUT_MODALITIES == {_TEXT_ONLY}
        assert target.SUPPORTED_OUTPUT_MODALITIES == {_TEXT_ONLY}

    def test_gpt4_vision_static_modalities(self):
        """Test that GPT-4 Vision correctly declares multimodal support."""
        target = MockOpenAIChatTarget("gpt-4-vision-preview")

        input_modalities = target.SUPPORTED_INPUT_MODALITIES

        # Should support text-only and text+image
        assert _TEXT_ONLY in input_modalities
        assert _TEXT_AND_IMAGE in input_modalities

    def test_optional_verification_utility(self):
        """Test the optional verification utility function."""
        # Mock target with required attributes
        mock_target = Mock()
        mock_target._async_client = Mock()
        mock_target.model_name = "gpt-4o"

        result = verify_target_capabilities(mock_target, ["image"])

        # Should return a dict with the tested modality
        assert isinstance(result, dict)
        assert "image" in result
        assert isinstance(result["image"], bool)

    def test_base_class_helper_methods(self):
        """Test the modality helper methods of the stand-in target."""
        target = MockOpenAIChatTarget("gpt-4o")

        # Test input modality checking
        assert target.input_modality_supported(_TEXT_ONLY)
        assert target.input_modality_supported(_TEXT_AND_IMAGE)
        assert not target.input_modality_supported(frozenset({"text", "audio_path"}))

        # Test output modality checking
        assert target.output_modality_supported(_TEXT_ONLY)
        assert not target.output_modality_supported(frozenset({"audio_path"}))

    @pytest.mark.parametrize(
        ("model_name", "should_support_image"),
        [
            # Current multimodal models
            ("gpt-4o", True),
            ("gpt-4o-mini", True),
            ("gpt-4o-2024-08-06", True),
            ("gpt-4-vision-preview", True),
            ("gpt-4-turbo", True),
            ("gpt-4-turbo-2024-04-09", True),
            # Future models (should be detected)
            ("gpt-5", True),
            ("gpt-4.5-preview", True),
            ("gpt-4-omni", True),
            ("gpt-4-multimodal", True),
            ("gpt-4-2024-12-01", True),  # Unknown GPT-4 variants assumed multimodal
            # Explicit text-only
            ("gpt-3.5-turbo", False),
            ("gpt-3.5-turbo-16k", False),
            ("text-davinci-003", False),
            # Old GPT-4 models (known text-only)
            ("gpt-4-0314", False),
            ("gpt-4-32k", False),
            # Unknown models (conservative default)
            ("custom-model", False),
            ("claude-3", False),
        ],
    )
    def test_model_pattern_matching(self, model_name, should_support_image):
        """Test that model pattern matching works correctly with future-proof heuristics."""
        target = MockOpenAIChatTarget(model_name)

        has_image_support = _TEXT_AND_IMAGE in target.SUPPORTED_INPUT_MODALITIES
        assert has_image_support == should_support_image, f"Failed for model: {model_name}"


class TestVerificationUtility:
    """Test the optional verification utility function."""

    def test_verify_missing_attributes(self):
        """Test verification with target missing required attributes."""
        # A bare Mock() fabricates any attribute on access, so hasattr() would
        # succeed; an empty spec makes _async_client and model_name truly absent.
        mock_target = Mock(spec=[])
        assert not hasattr(mock_target, "_async_client")
        assert not hasattr(mock_target, "model_name")

        result = verify_target_capabilities(mock_target, ["image"])

        # Should return False for all modalities
        assert result == {"image": False}

    def test_verify_default_modalities(self):
        """Test verification with default modalities."""
        mock_target = Mock()
        mock_target._async_client = Mock()
        mock_target.model_name = "test-model"

        result = verify_target_capabilities(mock_target)

        # Should test default modalities (image, audio)
        assert "image" in result
        assert "audio" in result
        assert len(result) == 2