Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pyrit/datasets/modality_test_assets/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Modality Test Assets

Benign, minimal test files used by `pyrit.prompt_target.modality_verification` to
verify which modalities a target actually supports at runtime.

- **test_image.png** — 1×1 white pixel PNG
- **test_audio.wav** — TTS-generated speech: "raccoons are extraordinary creatures"
- **test_video.mp4** — 1-frame, 16×16 solid color video

These are intentionally simple and non-controversial so they won't be blocked by
content filters during modality verification.
Binary file not shown.
Binary file added pyrit/datasets/modality_test_assets/test_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pyrit/datasets/modality_test_assets/test_video.mp4
Binary file not shown.
8 changes: 7 additions & 1 deletion pyrit/prompt_target/azure_blob_storage_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyrit.auth import AzureStorageAuth
from pyrit.common import default_values
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.models import Message, PromptDataType, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

Expand Down Expand Up @@ -49,6 +49,12 @@ class AzureBlobStorageTarget(PromptTarget):
AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_CONTAINER_URL"
SAS_TOKEN_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_SAS_TOKEN"

SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {
frozenset(["text"]),
frozenset(["url"]),
}
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["url"])}

def __init__(
self,
*,
Expand Down
4 changes: 4 additions & 0 deletions pyrit/prompt_target/azure_ml_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer
from pyrit.models import (
Message,
PromptDataType,
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
Expand All @@ -40,6 +41,9 @@ class AzureMLChatTarget(PromptChatTarget):
endpoint_uri_environment_variable: str = "AZURE_ML_MANAGED_ENDPOINT"
api_key_environment_variable: str = "AZURE_ML_KEY"

SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

def __init__(
self,
*,
Expand Down
59 changes: 58 additions & 1 deletion pyrit/prompt_target/common/prompt_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pyrit.identifiers import ComponentIdentifier, Identifiable
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import Message
from pyrit.models import Message, PromptDataType

logger = logging.getLogger(__name__)

Expand All @@ -26,6 +26,17 @@ class PromptTarget(Identifiable):
#: An empty list implies that the prompt target supports all converters.
supported_converters: List[Any]

#: Set of supported input modality combinations.
#: Each frozenset represents a valid combination of modalities that can be sent together.
#: For example: {frozenset(["text"]), frozenset(["text", "image_path"])}
#: means the target supports either text-only OR text+image combinations.
SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

#: Set of supported output modality combinations.
#: Each frozenset represents a valid combination of modalities that can be returned.
#: Most targets currently only support text output.
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

_identifier: Optional[ComponentIdentifier] = None

def __init__(
Expand Down Expand Up @@ -78,6 +89,52 @@ def _validate_request(self, *, message: Message) -> None:
message: The message to validate.
"""

def input_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Report whether an exact combination of input modalities is accepted.

    Membership is tested against SUPPORTED_INPUT_MODALITIES, so the match is
    exact: supporting {"text"} and {"text", "image_path"} separately does not
    imply support for any other subset or superset.

    Args:
        modalities: Set of modality types to check (e.g., {"text", "image_path"})

    Returns:
        True if this exact combination is supported, False otherwise
    """
    return frozenset(modalities) in self.SUPPORTED_INPUT_MODALITIES

def output_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Report whether an exact combination of output modalities is supported.

    The check is exact-set membership against SUPPORTED_OUTPUT_MODALITIES;
    most targets currently only declare text output.

    Args:
        modalities: Set of modality types to check

    Returns:
        True if this exact combination is supported, False otherwise
    """
    return frozenset(modalities) in self.SUPPORTED_OUTPUT_MODALITIES

async def verify_actual_modalities(self) -> set[frozenset[PromptDataType]]:
    """
    Probe this target at runtime to discover which input modality combinations it truly accepts.

    The static SUPPORTED_INPUT_MODALITIES declaration describes what the API
    allows; the set returned here may be a subset of that, based on sending
    minimal test requests to the live target.

    Returns:
        Set of actually supported input modality combinations

    Example:
        # Check what a specific OpenAI model actually supports
        actual = await target.verify_actual_modalities()
        # Returns: {frozenset(["text"])} or {frozenset(["text"]), frozenset(["text", "image_path"])}
    """
    # Imported lazily: modality_verification itself imports PromptTarget, so a
    # module-level import here would create a circular dependency.
    from pyrit.prompt_target.modality_verification import verify_target_modalities

    return await verify_target_modalities(self)

def set_model_name(self, *, model_name: str) -> None:
"""
Set the model name for this target.
Expand Down
5 changes: 4 additions & 1 deletion pyrit/prompt_target/crucible_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
handle_bad_request_exception,
pyrit_target_retry,
)
from pyrit.models import Message, construct_response_from_request
from pyrit.models import Message, PromptDataType, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

Expand All @@ -24,6 +24,9 @@ class CrucibleTarget(PromptTarget):

API_KEY_ENVIRONMENT_VARIABLE: str = "CRUCIBLE_API_KEY"

SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

def __init__(
self,
*,
Expand Down
5 changes: 4 additions & 1 deletion pyrit/prompt_target/gandalf_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pyrit.common import net_utility
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.models import Message, PromptDataType, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

Expand Down Expand Up @@ -38,6 +38,9 @@ class GandalfLevel(enum.Enum):
class GandalfTarget(PromptTarget):
"""A prompt target for the Gandalf security challenge."""

SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

def __init__(
self,
*,
Expand Down
4 changes: 4 additions & 0 deletions pyrit/prompt_target/http_target/http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pyrit.models import (
Message,
MessagePiece,
PromptDataType,
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_target import PromptTarget
Expand All @@ -39,6 +40,9 @@ class HTTPTarget(PromptTarget):
httpx_client_kwargs: (dict): additional keyword arguments to pass to the HTTP client
"""

SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

def __init__(
self,
http_request: str,
Expand Down
8 changes: 7 additions & 1 deletion pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pyrit.common.download_hf_model import download_specific_files
from pyrit.exceptions import EmptyResponseException, pyrit_target_retry
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.models import Message, PromptDataType, construct_response_from_request
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

Expand All @@ -34,6 +34,12 @@ class HuggingFaceChatTarget(PromptChatTarget):
Inherits from PromptTarget to comply with the current design standards.
"""

#: HuggingFace targets typically only support text input for now
SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

#: HuggingFace targets typically only support text output for now
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

# Class-level cache for model and tokenizer
_cached_model = None
_cached_tokenizer = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.models import Message, PromptDataType, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p

Expand All @@ -20,6 +20,9 @@ class HuggingFaceEndpointTarget(PromptTarget):
Inherits from PromptTarget to comply with the current design standards.
"""

SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset(["text"])}

def __init__(
self,
*,
Expand Down
156 changes: 156 additions & 0 deletions pyrit/prompt_target/modality_verification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Optional modality verification system for prompt targets.

This module provides runtime modality discovery to determine what modalities
a specific target actually supports, beyond what the API declares as possible.

Usage:
from pyrit.prompt_target.modality_verification import verify_target_modalities

# Get static API modalities
api_modalities = target.SUPPORTED_INPUT_MODALITIES

# Optionally verify actual model modalities
actual_modalities = await verify_target_modalities(target)
"""

import logging
import os
import uuid
from typing import Optional

from pyrit.common.path import DATASETS_PATH
from pyrit.models import Message, MessagePiece, PromptDataType
from pyrit.prompt_target.common.prompt_target import PromptTarget

logger = logging.getLogger(__name__)

# Path to the assets directory containing test files for modality verification
_ASSETS_DIR = DATASETS_PATH / "modality_test_assets"

# Mapping from a PromptDataType value to the benign test asset used to probe it.
# "text" has no entry here because _create_test_message builds it inline.
_TEST_ASSETS: dict[str, str] = {
    "image_path": str(_ASSETS_DIR / "test_image.png"),
    "audio_path": str(_ASSETS_DIR / "test_audio.wav"),
    "video_path": str(_ASSETS_DIR / "test_video.mp4"),
}


async def verify_target_modalities(
    target: PromptTarget,
    test_modalities: Optional[set[frozenset[PromptDataType]]] = None,
) -> set[frozenset[PromptDataType]]:
    """
    Probe a target with minimal requests and return the modality combinations it accepts.

    Each candidate combination (by default, the target's statically declared
    SUPPORTED_INPUT_MODALITIES) is exercised with a tiny benign request; only
    combinations that complete without error end up in the result, so the
    returned set may be a subset of the static API declarations.

    Args:
        target: The prompt target to test
        test_modalities: Specific modalities to test (defaults to target's declared modalities)

    Returns:
        Set of actually supported input modality combinations

    Example:
        actual = await verify_target_modalities(openai_target)
        # Returns: {frozenset(["text"])} or {frozenset(["text"]), frozenset(["text", "image_path"])}
    """
    candidates = target.SUPPORTED_INPUT_MODALITIES if test_modalities is None else test_modalities

    verified: set[frozenset[PromptDataType]] = set()
    for combo in candidates:
        try:
            # A probe failure is informational, not fatal: the combination is
            # simply left out of the verified set.
            if await _test_modality_combination(target, combo):
                verified.add(combo)
        except Exception as e:
            logger.info(f"Failed to verify {combo}: {e}")

    return verified


async def _test_modality_combination(
    target: PromptTarget,
    modalities: frozenset[PromptDataType],
) -> bool:
    """
    Send one minimal request exercising the given modalities and report whether it succeeded.

    Args:
        target: The target to test
        modalities: The combination of modalities to test

    Returns:
        True if the combination is supported, False otherwise
    """
    probe = _create_test_message(modalities)

    try:
        responses = await target.send_prompt_async(message=probe)

        # A request can succeed at the transport level yet still carry an
        # error piece; the first such piece marks the combination unsupported.
        failed_piece = next(
            (
                piece
                for response in responses
                for piece in response.message_pieces
                if piece.response_error != "none"
            ),
            None,
        )
        if failed_piece is not None:
            logger.info(f"Modality {modalities} returned error response: {failed_piece.converted_value}")
            return False

        return True

    except Exception as e:
        logger.info(f"Modality {modalities} not supported: {e}")
        return False


def _create_test_message(modalities: frozenset[PromptDataType]) -> Message:
    """
    Create a minimal test message for the specified modalities.

    Args:
        modalities: The modalities to include in the test message

    Returns:
        A Message object with minimal content for each requested modality

    Raises:
        FileNotFoundError: If a required test asset file is missing
        ValueError: If a modality has no configured test asset or no pieces could be created
    """
    pieces: list[MessagePiece] = []
    # Use a unique conversation id per call so repeated verification runs — or
    # concurrent verifications of different targets — never collide on one
    # conversation in memory or append to each other's history.
    conversation_id = f"modality-verification-{uuid.uuid4()}"

    for modality in modalities:
        if modality == "text":
            # Deliberately benign one-word prompt so content filters don't
            # interfere with the capability check.
            pieces.append(
                MessagePiece(
                    role="user",
                    original_value="test",
                    original_value_data_type="text",
                    conversation_id=conversation_id,
                )
            )
        elif modality in _TEST_ASSETS:
            asset_path = _TEST_ASSETS[modality]
            # Fail fast with a precise error instead of sending a dangling path
            # to the target and getting an opaque failure back.
            if not os.path.isfile(asset_path):
                raise FileNotFoundError(f"Test asset not found for modality '{modality}': {asset_path}")
            pieces.append(
                MessagePiece(
                    role="user",
                    original_value=asset_path,
                    original_value_data_type=modality,
                    conversation_id=conversation_id,
                )
            )
        else:
            raise ValueError(f"No test asset configured for modality: {modality}")

    if not pieces:
        raise ValueError(f"Could not create test message for modalities: {modalities}")

    return Message(pieces)
Loading