diff --git a/pyrit/datasets/modality_test_assets/README.md b/pyrit/datasets/modality_test_assets/README.md new file mode 100644 index 000000000..98d77b08b --- /dev/null +++ b/pyrit/datasets/modality_test_assets/README.md @@ -0,0 +1,11 @@ +# Modality Test Assets + +Benign, minimal test files used by `pyrit.prompt_target.modality_verification` to +verify which modalities a target actually supports at runtime. + +- **test_image.png** — 1×1 white pixel PNG +- **test_audio.wav** — TTS-generated speech: "raccoons are extraordinary creatures" +- **test_video.mp4** — 1-frame, 16×16 solid color video + +These are intentionally simple and non-controversial so they won't be blocked by +content filters during modality verification. diff --git a/pyrit/datasets/modality_test_assets/test_audio.wav b/pyrit/datasets/modality_test_assets/test_audio.wav new file mode 100644 index 000000000..681ed1e1c Binary files /dev/null and b/pyrit/datasets/modality_test_assets/test_audio.wav differ diff --git a/pyrit/datasets/modality_test_assets/test_image.png b/pyrit/datasets/modality_test_assets/test_image.png new file mode 100644 index 000000000..94381b429 Binary files /dev/null and b/pyrit/datasets/modality_test_assets/test_image.png differ diff --git a/pyrit/datasets/modality_test_assets/test_video.mp4 b/pyrit/datasets/modality_test_assets/test_video.mp4 new file mode 100644 index 000000000..1121fb0ab Binary files /dev/null and b/pyrit/datasets/modality_test_assets/test_video.mp4 differ diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 586913eca..e36ec9a44 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -13,7 +13,7 @@ from pyrit.auth import AzureStorageAuth from pyrit.common import default_values from pyrit.identifiers import ComponentIdentifier -from pyrit.models import Message, construct_response_from_request +from pyrit.models import Message, 
PromptDataType, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -49,6 +49,12 @@ class AzureBlobStorageTarget(PromptTarget): AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_CONTAINER_URL" SAS_TOKEN_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_SAS_TOKEN" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["url"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["url"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 24c49299d..ab38b030e 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -17,6 +17,7 @@ from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer from pyrit.models import ( Message, + PromptDataType, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget @@ -40,6 +41,9 @@ class AzureMLChatTarget(PromptChatTarget): endpoint_uri_environment_variable: str = "AZURE_ML_MANAGED_ENDPOINT" api_key_environment_variable: str = "AZURE_ML_KEY" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 902d3c10b..664456ceb 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -7,7 +7,7 @@ from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import Message +from pyrit.models import Message, PromptDataType logger = 
logging.getLogger(__name__) @@ -26,6 +26,17 @@ class PromptTarget(Identifiable): #: An empty list implies that the prompt target supports all converters. supported_converters: List[Any] + #: Set of supported input modality combinations. + #: Each frozenset represents a valid combination of modalities that can be sent together. + #: For example: {frozenset(["text"]), frozenset(["text", "image_path"])} + #: means the target supports either text-only OR text+image combinations. + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + + #: Set of supported output modality combinations. + #: Each frozenset represents a valid combination of modalities that can be returned. + #: Most targets currently only support text output. + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + _identifier: Optional[ComponentIdentifier] = None def __init__( @@ -78,6 +89,52 @@ def _validate_request(self, *, message: Message) -> None: message: The message to validate. """ + def input_modality_supported(self, modalities: set[PromptDataType]) -> bool: + """ + Check if a specific combination of input modalities is supported. + + Args: + modalities: Set of modality types to check (e.g., {"text", "image_path"}) + + Returns: + True if this exact combination is supported, False otherwise + """ + modalities_frozen = frozenset(modalities) + return modalities_frozen in self.SUPPORTED_INPUT_MODALITIES + + def output_modality_supported(self, modalities: set[PromptDataType]) -> bool: + """ + Check if a specific combination of output modalities is supported. + Most targets only support text output currently. 
+ + Args: + modalities: Set of modality types to check + + Returns: + True if this exact combination is supported, False otherwise + """ + modalities_frozen = frozenset(modalities) + return modalities_frozen in self.SUPPORTED_OUTPUT_MODALITIES + + async def verify_actual_modalities(self) -> set[frozenset[PromptDataType]]: + """ + Verify what modalities this target actually supports at runtime. + + This optional verification tests the target with minimal requests to determine + actual capabilities, which may be a subset of the static API declarations. + + Returns: + Set of actually supported input modality combinations + + Example: + # Check what a specific OpenAI model actually supports + actual = await target.verify_actual_modalities() + # Returns: {frozenset(["text"])} or {frozenset(["text"]), frozenset(["text", "image_path"])} + """ + from pyrit.prompt_target.modality_verification import verify_target_modalities + + return await verify_target_modalities(self) + def set_model_name(self, *, model_name: str) -> None: """ Set the model name for this target. 
diff --git a/pyrit/prompt_target/crucible_target.py b/pyrit/prompt_target/crucible_target.py index 1f1fe974c..8d5dfa59c 100644 --- a/pyrit/prompt_target/crucible_target.py +++ b/pyrit/prompt_target/crucible_target.py @@ -12,7 +12,7 @@ handle_bad_request_exception, pyrit_target_retry, ) -from pyrit.models import Message, construct_response_from_request +from pyrit.models import Message, PromptDataType, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -24,6 +24,9 @@ class CrucibleTarget(PromptTarget): API_KEY_ENVIRONMENT_VARIABLE: str = "CRUCIBLE_API_KEY" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 5e3e89935..42153c53d 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -8,7 +8,7 @@ from pyrit.common import net_utility from pyrit.identifiers import ComponentIdentifier -from pyrit.models import Message, construct_response_from_request +from pyrit.models import Message, PromptDataType, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -38,6 +38,9 @@ class GandalfLevel(enum.Enum): class GandalfTarget(PromptTarget): """A prompt target for the Gandalf security challenge.""" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 50ca68a88..4bfdeaff3 100644 --- 
a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -13,6 +13,7 @@ from pyrit.models import ( Message, MessagePiece, + PromptDataType, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -39,6 +40,9 @@ class HTTPTarget(PromptTarget): httpx_client_kwargs: (dict): additional keyword arguments to pass to the HTTP client """ + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, http_request: str, diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 6eff804f6..9fd47fdad 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -18,7 +18,7 @@ from pyrit.common.download_hf_model import download_specific_files from pyrit.exceptions import EmptyResponseException, pyrit_target_retry from pyrit.identifiers import ComponentIdentifier -from pyrit.models import Message, construct_response_from_request +from pyrit.models import Message, PromptDataType, construct_response_from_request from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -34,6 +34,12 @@ class HuggingFaceChatTarget(PromptChatTarget): Inherits from PromptTarget to comply with the current design standards. 
""" + #: HuggingFace targets typically only support text input for now + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + + #: HuggingFace targets typically only support text output for now + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + # Class-level cache for model and tokenizer _cached_model = None _cached_tokenizer = None diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index a21c87365..42735c841 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -6,7 +6,7 @@ from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.identifiers import ComponentIdentifier -from pyrit.models import Message, construct_response_from_request +from pyrit.models import Message, PromptDataType, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -20,6 +20,9 @@ class HuggingFaceEndpointTarget(PromptTarget): Inherits from PromptTarget to comply with the current design standards. """ + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/modality_verification.py b/pyrit/prompt_target/modality_verification.py new file mode 100644 index 000000000..193e1131b --- /dev/null +++ b/pyrit/prompt_target/modality_verification.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Optional modality verification system for prompt targets. 
+ +This module provides runtime modality discovery to determine what modalities +a specific target actually supports, beyond what the API declares as possible. + +Usage: + from pyrit.prompt_target.modality_verification import verify_target_modalities + + # Get static API modalities + api_modalities = target.SUPPORTED_INPUT_MODALITIES + + # Optionally verify actual model modalities + actual_modalities = await verify_target_modalities(target) +""" + +import logging +import os +from typing import Optional + +from pyrit.common.path import DATASETS_PATH +from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.prompt_target.common.prompt_target import PromptTarget + +logger = logging.getLogger(__name__) + +# Path to the assets directory containing test files for modality verification +_ASSETS_DIR = DATASETS_PATH / "modality_test_assets" + +# Mapping from PromptDataType to test asset filenames +_TEST_ASSETS: dict[str, str] = { + "image_path": str(_ASSETS_DIR / "test_image.png"), + "audio_path": str(_ASSETS_DIR / "test_audio.wav"), + "video_path": str(_ASSETS_DIR / "test_video.mp4"), +} + + +async def verify_target_modalities( + target: PromptTarget, + test_modalities: Optional[set[frozenset[PromptDataType]]] = None, +) -> set[frozenset[PromptDataType]]: + """ + Verify which modality combinations a target actually supports. + + This function tests the target with minimal requests to determine actual + modalities, trimming down from the static API declarations. 
+ + Args: + target: The prompt target to test + test_modalities: Specific modalities to test (defaults to target's declared modalities) + + Returns: + Set of actually supported input modality combinations + + Example: + actual = await verify_target_modalities(openai_target) + # Returns: {frozenset(["text"])} or {frozenset(["text"]), frozenset(["text", "image_path"])} + """ + if test_modalities is None: + test_modalities = target.SUPPORTED_INPUT_MODALITIES + + verified_modalities: set[frozenset[PromptDataType]] = set() + + for modality_combination in test_modalities: + try: + is_supported = await _test_modality_combination(target, modality_combination) + if is_supported: + verified_modalities.add(modality_combination) + except Exception as e: + logger.info(f"Failed to verify {modality_combination}: {e}") + + return verified_modalities + + +async def _test_modality_combination( + target: PromptTarget, + modalities: frozenset[PromptDataType], +) -> bool: + """ + Test a specific modality combination with a minimal API request. + + Args: + target: The target to test + modalities: The combination of modalities to test + + Returns: + True if the combination is supported, False otherwise + """ + test_message = _create_test_message(modalities) + + try: + responses = await target.send_prompt_async(message=test_message) + + # Check if the response itself indicates an error + for response in responses: + for piece in response.message_pieces: + if piece.response_error != "none": + logger.info(f"Modality {modalities} returned error response: {piece.converted_value}") + return False + + return True + + except Exception as e: + logger.info(f"Modality {modalities} not supported: {e}") + return False + + +def _create_test_message(modalities: frozenset[PromptDataType]) -> Message: + """ + Create a minimal test message for the specified modalities. 
+ + Args: + modalities: The modalities to include in the test message + + Returns: + A Message object with minimal content for each requested modality + + Raises: + FileNotFoundError: If a required test asset file is missing + ValueError: If a modality has no configured test asset or no pieces could be created + """ + pieces: list[MessagePiece] = [] + conversation_id = "modality-verification-test" + + for modality in modalities: + if modality == "text": + pieces.append( + MessagePiece( + role="user", + original_value="test", + original_value_data_type="text", + conversation_id=conversation_id, + ) + ) + elif modality in _TEST_ASSETS: + asset_path = _TEST_ASSETS[modality] + if not os.path.isfile(asset_path): + raise FileNotFoundError(f"Test asset not found for modality '{modality}': {asset_path}") + pieces.append( + MessagePiece( + role="user", + original_value=asset_path, + original_value_data_type=modality, + conversation_id=conversation_id, + ) + ) + else: + raise ValueError(f"No test asset configured for modality: {modality}") + + if not pieces: + raise ValueError(f"Could not create test message for modalities: {modalities}") + + return Message(pieces) diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f660939b5..39ef7bb92 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -18,6 +18,7 @@ DataTypeSerializer, Message, MessagePiece, + PromptDataType, construct_response_from_request, data_serializer_factory, ) @@ -62,6 +63,20 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): """ + #: OpenAI Chat API supports these input modality combinations + #: This represents what the API can handle, not what specific models support + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), # All models support text-only + frozenset(["text", "image_path"]), # API supports vision when model does + 
frozenset(["text", "audio_path"]), # API supports audio input when model does + } + + #: OpenAI Chat API output modalities + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), # Currently only text output + frozenset(["audio_path"]), # Audio output when audio_response_config is set + } + def __init__( self, *, diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 7731cdb91..00458aae5 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -8,7 +8,7 @@ pyrit_target_retry, ) from pyrit.identifiers import ComponentIdentifier -from pyrit.models import Message, construct_response_from_request +from pyrit.models import Message, PromptDataType, construct_response_from_request from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.openai.openai_target import OpenAITarget @@ -18,6 +18,9 @@ class OpenAICompletionTarget(OpenAITarget): """A prompt target for OpenAI completion endpoints.""" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, max_tokens: Optional[int] = None, diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index d0caa44e1..07fccf3eb 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -13,6 +13,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, + PromptDataType, construct_response_from_request, data_serializer_factory, ) @@ -25,6 +26,14 @@ class OpenAIImageTarget(OpenAITarget): """A target for image generation or editing using OpenAI's image models.""" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + 
frozenset(["text"]), + frozenset(["text", "image_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["image_path"]), + } + # Maximum number of image inputs supported by the OpenAI image API _MAX_INPUT_IMAGES = 16 diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 8c4b98d7e..54bbf010e 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -18,6 +18,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, + PromptDataType, construct_response_from_request, data_serializer_factory, ) @@ -66,6 +67,16 @@ class RealtimeTarget(OpenAITarget): and https://platform.openai.com/docs/guides/realtime-websocket """ + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["text", "audio_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["audio_path"]), + frozenset(["text", "audio_path"]), + } + def __init__( self, *, diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 1eef3f49b..637327786 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -71,6 +71,12 @@ class OpenAIResponseTarget(OpenAITarget, PromptChatTarget): https://platform.openai.com/docs/api-reference/responses/create """ + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["text", "image_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, @@ -223,7 +229,7 @@ async def _construct_input_item_from_piece(self, piece: MessagePiece) -> Dict[st } if piece.converted_value_data_type == "image_path": data_url = await 
convert_local_image_to_data_url(piece.converted_value) - return {"type": "input_image", "image_url": {"url": data_url}} + return {"type": "input_image", "image_url": data_url} raise ValueError(f"Unsupported piece type for inline content: {piece.converted_value_data_type}") async def _build_input_for_multi_modal_async(self, conversation: MutableSequence[Message]) -> List[Dict[str, Any]]: diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 610bf68fd..fbd2468e6 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -10,6 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, + PromptDataType, construct_response_from_request, data_serializer_factory, ) @@ -26,6 +27,9 @@ class OpenAITTSTarget(OpenAITarget): """A prompt target for OpenAI Text-to-Speech (TTS) endpoints.""" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["audio_path"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 3e915a37d..fd07816e9 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -16,6 +16,7 @@ DataTypeSerializer, Message, MessagePiece, + PromptDataType, construct_response_from_request, data_serializer_factory, ) @@ -48,6 +49,14 @@ class OpenAIVideoTarget(OpenAITarget): Supported image formats for text+image-to-video: JPEG, PNG, WEBP """ + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["text", "image_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["video_path"]), + } + SUPPORTED_RESOLUTIONS: list[VideoSize] = ["720x1280", "1280x720", "1024x1792", "1792x1024"] 
SUPPORTED_DURATIONS: list[VideoSeconds] = ["4", "8", "12"] SUPPORTED_IMAGE_FORMATS: list[str] = ["image/jpeg", "image/png", "image/webp"] diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 73fe98bc4..2f822f08e 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -78,6 +78,16 @@ class PlaywrightCopilotTarget(PromptTarget): # Supported data types SUPPORTED_DATA_TYPES = {"text", "image_path"} + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["text", "image_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["image_path"]), + frozenset(["text", "image_path"]), + } + # Placeholder text constants PLACEHOLDER_GENERATING_RESPONSE: str = "generating response" PLACEHOLDER_GENERATING: str = "generating" diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index 6e4fac6a7..06e4cf7ef 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -5,6 +5,7 @@ from pyrit.models import ( Message, + PromptDataType, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -52,6 +53,12 @@ class PlaywrightTarget(PromptTarget): # Supported data types SUPPORTED_DATA_TYPES = {"text", "image_path"} + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["text", "image_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 9b27b931d..c3f04d910 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -10,6 +10,7 @@ from pyrit.models import ( 
Message, MessagePiece, + PromptDataType, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -48,6 +49,9 @@ class PromptShieldTarget(PromptTarget): ENDPOINT_URI_ENVIRONMENT_VARIABLE: str = "AZURE_CONTENT_SAFETY_API_ENDPOINT" API_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_CONTENT_SAFETY_API_KEY" + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + _endpoint: str _api_key: str | Callable[[], str] | None _api_version: str diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index ecc9bf201..eb8ee68c2 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import IO -from pyrit.models import Message, MessagePiece +from pyrit.models import Message, MessagePiece, PromptDataType from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -20,6 +20,12 @@ class TextTarget(PromptTarget): but enter them manually. 
""" + #: Text targets only support text input + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + + #: Text targets only support text output + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + def __init__( self, *, diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index f4d5105b7..05f95c41a 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -20,7 +20,7 @@ pyrit_target_retry, ) from pyrit.identifiers import ComponentIdentifier -from pyrit.models import DataTypeSerializer, Message, MessagePiece, construct_response_from_request +from pyrit.models import DataTypeSerializer, Message, MessagePiece, PromptDataType, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute logger = logging.getLogger(__name__) @@ -71,6 +71,13 @@ class WebSocketCopilotTarget(PromptTarget): """ SUPPORTED_DATA_TYPES = {"text", "image_path"} + + SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = { + frozenset(["text"]), + frozenset(["text", "image_path"]), + } + SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])} + RESPONSE_TIMEOUT_SECONDS: int = 60 CONNECTION_TIMEOUT_SECONDS: int = 30 diff --git a/tests/integration/targets/test_entra_auth_targets.py b/tests/integration/targets/test_entra_auth_targets.py index bf068bfb2..84a64228f 100644 --- a/tests/integration/targets/test_entra_auth_targets.py +++ b/tests/integration/targets/test_entra_auth_targets.py @@ -11,7 +11,7 @@ get_azure_openai_auth, get_azure_token_provider, ) -from pyrit.common.path import HOME_PATH +from pyrit.common.path import DATASETS_PATH, HOME_PATH from pyrit.executor.attack import PromptSendingAttack from pyrit.models import Message, MessagePiece from pyrit.prompt_target import ( @@ -237,6 +237,51 @@ async def 
test_openai_responses_target_entra_auth(sqlite_instance, endpoint, mod assert result.last_response is not None +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("endpoint", "model_name"), + [ + ("OPENAI_RESPONSES_ENDPOINT", "OPENAI_RESPONSES_MODEL"), + ("AZURE_OPENAI_GPT41_RESPONSES_ENDPOINT", "AZURE_OPENAI_GPT41_RESPONSES_MODEL"), + ("AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT", "AZURE_OPENAI_GPT5_MODEL"), + ], +) +async def test_openai_responses_target_entra_auth_image(sqlite_instance, endpoint, model_name): + """Verify Response API image input works with Entra auth (image_url as plain string).""" + endpoint_value = os.environ[endpoint] + args = { + "endpoint": endpoint_value, + "model_name": os.environ[model_name], + "api_key": get_azure_openai_auth(endpoint_value), + } + + target = OpenAIResponseTarget(**args) + + conv_id = str(uuid.uuid4()) + test_image = str(DATASETS_PATH / "modality_test_assets" / "test_image.png") + + text_piece = MessagePiece( + role="user", + original_value="Describe this image briefly.", + original_value_data_type="text", + conversation_id=conv_id, + ) + image_piece = MessagePiece( + role="user", + original_value=test_image, + original_value_data_type="image_path", + conversation_id=conv_id, + ) + message = Message([text_piece, image_piece]) + + result = await target.send_prompt_async(message=message) + assert result is not None + assert len(result) >= 1 + response_text = result[0].message_pieces[-1].converted_value + assert response_text is not None + assert len(response_text) > 0 + + @pytest.mark.asyncio @pytest.mark.parametrize( ("endpoint", "model_name"), diff --git a/tests/integration/targets/test_modality_verification_integration.py b/tests/integration/targets/test_modality_verification_integration.py new file mode 100644 index 000000000..4ce245cb9 --- /dev/null +++ b/tests/integration/targets/test_modality_verification_integration.py @@ -0,0 +1,228 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +""" +Integration tests for modality verification across popular targets. + +These tests call verify_target_modalities() against real endpoints to confirm +that modality detection works end-to-end. Each target kind is represented at +least once. +""" + +import os + +import pytest + +from pyrit.prompt_target import ( + OpenAIChatTarget, + OpenAIImageTarget, + OpenAIResponseTarget, + OpenAITTSTarget, + OpenAIVideoTarget, + TextTarget, +) +from pyrit.prompt_target.modality_verification import verify_target_modalities + + +def _get_required_env_var(env_var_name: str) -> str: + value = os.getenv(env_var_name) + if not value: + raise ValueError(f"Environment variable {env_var_name} is not set.") + return value + + +# --------------------------------------------------------------------------- +# TextTarget – no credentials needed +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_verify_modalities_text_target(sqlite_instance): + """TextTarget supports text-only. Verification should confirm this without any API call.""" + target = TextTarget() + + result = await verify_target_modalities(target) + # TextTarget.send_prompt_async writes to a stream, so the text modality should succeed + assert frozenset(["text"]) in result + + +# --------------------------------------------------------------------------- +# OpenAI Chat – vision-capable model (e.g. 
gpt-4o) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_verify_modalities_openai_chat_vision(sqlite_instance): + """A vision-capable OpenAI model should support text and text+image.""" + endpoint = _get_required_env_var("AZURE_OPENAI_GPT4O_ENDPOINT") + api_key = _get_required_env_var("AZURE_OPENAI_GPT4O_KEY") + model_name = _get_required_env_var("AZURE_OPENAI_GPT4O_MODEL") + + target = OpenAIChatTarget( + endpoint=endpoint, + api_key=api_key, + model_name=model_name, + ) + + result = await verify_target_modalities(target) + assert frozenset(["text"]) in result + assert frozenset(["text", "image_path"]) in result + + +# --------------------------------------------------------------------------- +# OpenAI Chat – text-only model (e.g. gpt-3.5) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_verify_modalities_openai_chat_text_only(sqlite_instance): + """A text-only OpenAI model may still accept image input without error (ignoring it). + + Verification detects modalities that the API *rejects*, not what the model + truly understands. GPT-3.5 accepts images silently, so we only assert + that text is confirmed supported. 
+ """ + endpoint = os.getenv("AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT") + api_key = os.getenv("AZURE_OPENAI_GPT3_5_CHAT_KEY") + model_name = os.getenv("AZURE_OPENAI_GPT3_5_CHAT_MODEL") + + if not endpoint or not api_key or not model_name: + pytest.skip("GPT-3.5 env vars not set") + + target = OpenAIChatTarget( + endpoint=endpoint, + api_key=api_key, + model_name=model_name, + ) + + result = await verify_target_modalities(target) + assert frozenset(["text"]) in result + + +# --------------------------------------------------------------------------- +# OpenAI Chat – negative case: gpt-4 with text+audio should fail +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_verify_modalities_openai_chat_no_audio(sqlite_instance): + """GPT-4 does not support audio input. Verification should exclude text+audio.""" + endpoint = os.getenv("AZURE_OPENAI_GPT4_CHAT_ENDPOINT") + api_key = os.getenv("AZURE_OPENAI_GPT4_CHAT_KEY") + model_name = os.getenv("AZURE_OPENAI_GPT4_CHAT_MODEL") + + if not endpoint or not api_key or not model_name: + pytest.skip("GPT-4 env vars not set") + + target = OpenAIChatTarget( + endpoint=endpoint, + api_key=api_key, + model_name=model_name, + ) + + result = await verify_target_modalities(target) + assert frozenset(["text"]) in result + assert frozenset(["text", "audio_path"]) not in result + + +# --------------------------------------------------------------------------- +# OpenAI Response API – GPT-5 +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_verify_modalities_openai_response_gpt5(sqlite_instance): + """GPT-5 on the Responses API should support text and text+image.""" + endpoint = os.getenv("AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT") + api_key = os.getenv("AZURE_OPENAI_GPT5_KEY") + model_name = os.getenv("AZURE_OPENAI_GPT5_MODEL") + + if not endpoint or not api_key or not model_name: + pytest.skip("GPT-5 Responses env 
vars not set")
+
+    target = OpenAIResponseTarget(
+        endpoint=endpoint,
+        api_key=api_key,
+        model_name=model_name,
+    )
+
+    result = await verify_target_modalities(target)
+    assert frozenset(["text"]) in result
+    assert frozenset(["text", "image_path"]) in result
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Image API – gpt-image
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_verify_modalities_openai_image(sqlite_instance):
+    """Image target should support text-to-image generation; text+image editing is not asserted here."""
+    endpoint = os.getenv("OPENAI_IMAGE_ENDPOINT2")
+    api_key = os.getenv("OPENAI_IMAGE_API_KEY2")
+    model_name = os.getenv("OPENAI_IMAGE_MODEL2")
+
+    if not endpoint or not api_key or not model_name:
+        pytest.skip("Image API env vars not set")
+
+    target = OpenAIImageTarget(
+        endpoint=endpoint,
+        api_key=api_key,
+        model_name=model_name,
+    )
+
+    result = await verify_target_modalities(target)
+    assert frozenset(["text"]) in result
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Video API – Sora
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_verify_modalities_openai_video_sora(sqlite_instance):
+    """Sora video target should support text-to-video."""
+    endpoint = os.getenv("AZURE_OPENAI_VIDEO_ENDPOINT")
+    api_key = os.getenv("AZURE_OPENAI_VIDEO_KEY")
+    model_name = os.getenv("AZURE_OPENAI_VIDEO_MODEL")
+
+    if not endpoint or not api_key or not model_name:
+        pytest.skip("Video/Sora env vars not set")
+
+    target = OpenAIVideoTarget(
+        endpoint=endpoint,
+        api_key=api_key,
+        model_name=model_name,
+    )
+
+    result = await verify_target_modalities(target)
+    assert frozenset(["text"]) in result
+
+
+# ---------------------------------------------------------------------------
+# OpenAI TTS – text input only
+# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_verify_modalities_openai_tts(sqlite_instance): + """TTS target accepts text input. Verification should confirm text is supported.""" + endpoint = os.getenv("OPENAI_TTS_ENDPOINT") + api_key = os.getenv("OPENAI_TTS_KEY") + model_name = os.getenv("OPENAI_TTS_MODEL") + + if not endpoint or not api_key or not model_name: + pytest.skip("TTS env vars not set") + + target = OpenAITTSTarget( + endpoint=endpoint, + api_key=api_key, + model_name=model_name, + voice="alloy", + response_format="wav", + ) + + result = await verify_target_modalities(target) + assert frozenset(["text"]) in result diff --git a/tests/integration/targets/test_openai_responses_gpt5.py b/tests/integration/targets/test_openai_responses_gpt5.py index 8f6bf3ae4..803bf10d5 100644 --- a/tests/integration/targets/test_openai_responses_gpt5.py +++ b/tests/integration/targets/test_openai_responses_gpt5.py @@ -10,7 +10,8 @@ import pytest # from pyrit.auth import get_azure_openai_auth -from pyrit.models import MessagePiece +from pyrit.common.path import DATASETS_PATH +from pyrit.models import Message, MessagePiece from pyrit.prompt_target import OpenAIResponseTarget @@ -145,6 +146,37 @@ async def test_openai_responses_gpt5_json_object(sqlite_instance, gpt5_args): # Can't assert more, since the failure could be due to a bad generation by the model +@pytest.mark.asyncio +async def test_openai_responses_gpt5_image(sqlite_instance, gpt5_args): + """GPT-5 on the Responses API should accept text+image input (image_url as plain string).""" + target = OpenAIResponseTarget(**gpt5_args) + + conv_id = str(uuid.uuid4()) + test_image = str(DATASETS_PATH / "modality_test_assets" / "test_image.png") + + text_piece = MessagePiece( + role="user", + original_value="Describe this image briefly.", + original_value_data_type="text", + conversation_id=conv_id, + ) + image_piece = MessagePiece( + role="user", + 
original_value=test_image, + original_value_data_type="image_path", + conversation_id=conv_id, + ) + message = Message([text_piece, image_piece]) + + result = await target.send_prompt_async(message=message) + assert result is not None + assert len(result) >= 1 + # The assistant should produce a text response describing the image + response_text = result[0].message_pieces[-1].converted_value + assert response_text is not None + assert len(response_text) > 0 + + @pytest.mark.asyncio async def test_openai_responses_gpt5_reasoning_effort(sqlite_instance, gpt5_args): target = OpenAIResponseTarget(**gpt5_args, reasoning_effort="low") diff --git a/tests/integration/targets/test_targets_and_secrets.py b/tests/integration/targets/test_targets_and_secrets.py index 8084e4a93..eeeb75e0b 100644 --- a/tests/integration/targets/test_targets_and_secrets.py +++ b/tests/integration/targets/test_targets_and_secrets.py @@ -9,7 +9,7 @@ import pytest from PIL import Image -from pyrit.common.path import HOME_PATH +from pyrit.common.path import DATASETS_PATH, HOME_PATH from pyrit.executor.attack import AttackExecutor, PromptSendingAttack from pyrit.models import Message, MessagePiece from pyrit.prompt_target import ( @@ -197,6 +197,69 @@ async def test_connect_required_openai_response_targets(sqlite_instance, endpoin await _assert_can_send_prompt(target, check_if_llm_interpreted_request=False) +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("endpoint", "api_key", "model_name"), + [ + ( + "PLATFORM_OPENAI_RESPONSES_ENDPOINT", + "PLATFORM_OPENAI_RESPONSES_KEY", + "PLATFORM_OPENAI_RESPONSES_MODEL", + ), + ( + "AZURE_OPENAI_RESPONSES_ENDPOINT", + "AZURE_OPENAI_RESPONSES_KEY", + "AZURE_OPENAI_RESPONSES_MODEL", + ), + ( + "AZURE_OPENAI_GPT41_RESPONSES_ENDPOINT", + "AZURE_OPENAI_GPT41_RESPONSES_KEY", + "AZURE_OPENAI_GPT41_RESPONSES_MODEL", + ), + ( + "AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT", + "AZURE_OPENAI_GPT5_KEY", + "AZURE_OPENAI_GPT5_MODEL", + ), + ], +) +async def 
test_connect_required_openai_response_targets_image(sqlite_instance, endpoint, api_key, model_name): + """Verify Response API targets accept text+image input (image_url as plain string).""" + endpoint_value = _get_required_env_var(endpoint) + api_key_value = _get_required_env_var(api_key) + model_name_value = _get_required_env_var(model_name) + + target = OpenAIResponseTarget( + endpoint=endpoint_value, + api_key=api_key_value, + model_name=model_name_value, + ) + + conv_id = str(uuid.uuid4()) + test_image = str(DATASETS_PATH / "modality_test_assets" / "test_image.png") + + text_piece = MessagePiece( + role="user", + original_value="Describe this image briefly.", + original_value_data_type="text", + conversation_id=conv_id, + ) + image_piece = MessagePiece( + role="user", + original_value=test_image, + original_value_data_type="image_path", + conversation_id=conv_id, + ) + message = Message([text_piece, image_piece]) + + result = await target.send_prompt_async(message=message) + assert result is not None + assert len(result) >= 1 + response_text = result[0].message_pieces[-1].converted_value + assert response_text is not None + assert len(response_text) > 0 + + @pytest.mark.asyncio @pytest.mark.parametrize( ("endpoint", "api_key", "model_name"), diff --git a/tests/unit/prompt_target/test_modality_support_clean.py b/tests/unit/prompt_target/test_modality_support_clean.py new file mode 100644 index 000000000..b511f0a3b --- /dev/null +++ b/tests/unit/prompt_target/test_modality_support_clean.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for modality support detection using set[frozenset[PromptDataType]] architecture. 
+ +- SUPPORTED_INPUT_MODALITIES is set[frozenset[PromptDataType]] +- Each frozenset represents a valid combination of modalities +- Exact frozenset matching for precise modality detection +""" + +from unittest.mock import AsyncMock + +import pytest + +from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.prompt_target.modality_verification import ( + _create_test_message, + verify_target_modalities, +) +from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget +from pyrit.prompt_target.text_target import TextTarget + + +class TestModalitySupport: + """Test modality support detection with set[frozenset[PromptDataType]] architecture.""" + + def test_text_target_input_modalities(self, patch_central_database): + """Test TextTarget only supports text input.""" + target = TextTarget() + + assert target.input_modality_supported({"text"}) + assert not target.input_modality_supported({"text", "image_path"}) + assert not target.input_modality_supported({"image_path"}) + assert not target.input_modality_supported({"text", "audio_path"}) + + def test_text_target_output_modalities(self, patch_central_database): + """Test TextTarget only supports text output.""" + target = TextTarget() + + assert target.output_modality_supported({"text"}) + assert not target.output_modality_supported({"image_path"}) + assert not target.output_modality_supported({"text", "image_path"}) + + expected_output = {frozenset(["text"])} + assert target.SUPPORTED_OUTPUT_MODALITIES == expected_output + + def test_openai_static_api_declarations(self, patch_central_database): + """Test OpenAI uses static API modality declarations, not model-name pattern matching. + + All OpenAI models get the same static API declarations regardless of model name. + The optional verify_actual_modalities() trims these down at runtime. 
+ """ + model_names = ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "some-future-model-xyz"] + + for model_name in model_names: + target = OpenAIChatTarget( + model_name=model_name, + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + expected_api_modalities = { + frozenset(["text"]), + frozenset(["text", "image_path"]), + frozenset(["text", "audio_path"]), + } + assert target.SUPPORTED_INPUT_MODALITIES == expected_api_modalities, ( + f"Model {model_name} should declare full API modalities" + ) + + assert target.input_modality_supported({"text"}) + assert target.input_modality_supported({"text", "image_path"}) + assert target.input_modality_supported({"text", "audio_path"}) + + def test_openai_unsupported_combinations(self, patch_central_database): + """Test that OpenAI rejects modality combinations not declared by the API.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + assert not target.input_modality_supported({"image_path"}) + assert not target.input_modality_supported({"audio_path"}) + assert not target.input_modality_supported({"text", "image_path", "audio_path"}) + + def test_frozenset_order_independence(self, patch_central_database): + """Test that modality checking is order-independent via frozenset matching.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + assert target.input_modality_supported({"image_path", "text"}) + assert target.input_modality_supported({"text", "image_path"}) + + def test_verify_actual_modalities_exists(self, patch_central_database): + """Test the optional runtime verification method exists.""" + target = TextTarget() + assert hasattr(target, "verify_actual_modalities") + + def test_modality_type_validation(self, patch_central_database): + """Test that modality checking works with PromptDataType literals.""" + target = TextTarget() + + text_type: PromptDataType = 
"text" + image_type: PromptDataType = "image_path" + audio_type: PromptDataType = "audio_path" + + assert target.input_modality_supported({text_type}) + assert not target.input_modality_supported({text_type, image_type}) + assert not target.input_modality_supported({audio_type}) + + def test_create_test_message_single_modality(self): + """Test that _create_test_message works for a single text modality.""" + msg = _create_test_message(frozenset(["text"])) + assert len(msg.message_pieces) == 1 + assert msg.message_pieces[0].original_value_data_type == "text" + assert msg.message_pieces[0].original_value == "test" + + def test_create_test_message_multimodal(self): + """Test that _create_test_message creates a valid Message for multimodal inputs. + + All pieces must share the same conversation_id and role for Message.validate() to pass. + """ + msg = _create_test_message(frozenset(["text", "image_path"])) + assert len(msg.message_pieces) == 2 + data_types = {p.original_value_data_type for p in msg.message_pieces} + assert data_types == {"text", "image_path"} + + # Verify all pieces share conversation_id (required by Message.validate) + conv_ids = {p.conversation_id for p in msg.message_pieces} + assert len(conv_ids) == 1 + + @pytest.mark.asyncio + async def test_verify_target_modalities_success(self, patch_central_database): + """Test verify_target_modalities returns supported modalities on success.""" + target = TextTarget() + + # Mock send_prompt_async to return a successful response + response_piece = MessagePiece( + role="assistant", + original_value="ok", + original_value_data_type="text", + response_error="none", + ) + mock_response = Message([response_piece]) + target.send_prompt_async = AsyncMock(return_value=[mock_response]) + + result = await verify_target_modalities(target) + assert frozenset(["text"]) in result + + @pytest.mark.asyncio + async def test_verify_target_modalities_exception(self, patch_central_database): + """Test verify_target_modalities 
excludes modalities that raise exceptions.""" + target = TextTarget() + target.send_prompt_async = AsyncMock(side_effect=Exception("unsupported modality")) + + result = await verify_target_modalities(target) + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_verify_target_modalities_error_response(self, patch_central_database): + """Test verify_target_modalities excludes modalities returning error responses.""" + target = TextTarget() + + response_piece = MessagePiece( + role="assistant", + original_value="content filter triggered", + original_value_data_type="text", + response_error="blocked", + ) + mock_response = Message([response_piece]) + target.send_prompt_async = AsyncMock(return_value=[mock_response]) + + result = await verify_target_modalities(target) + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_verify_target_modalities_partial_support(self, patch_central_database): + """Test verify_target_modalities with a target that supports some but not all modalities.""" + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + # Text succeeds, text+image raises + async def selective_send(*, message): + types = {p.original_value_data_type for p in message.message_pieces} + if "image_path" in types: + raise Exception("image not supported by this model") + response_piece = MessagePiece( + role="assistant", + original_value="ok", + original_value_data_type="text", + response_error="none", + ) + return [Message([response_piece])] + + target.send_prompt_async = selective_send + + result = await verify_target_modalities(target) + assert frozenset(["text"]) in result + assert frozenset(["text", "image_path"]) not in result diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 1e4b95e0b..9ec0d7548 100644 --- a/tests/unit/target/test_openai_response_target.py +++ 
b/tests/unit/target/test_openai_response_target.py @@ -707,7 +707,7 @@ async def test_build_input_for_multi_modal_async_image_and_text(target: OpenAIRe assert result[0]["role"] == "user" assert result[0]["content"][0]["type"] == "input_text" assert result[0]["content"][1]["type"] == "input_image" - assert result[0]["content"][1]["image_url"]["url"].startswith("data:image/jpeg;base64,") + assert result[0]["content"][1]["image_url"].startswith("data:image/jpeg;base64,") @pytest.mark.asyncio