Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions pyrit/common/modality_discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import asyncio
import logging
from typing import Any, Dict, List, Optional

from pyrit.models.literals import PromptDataType

logger = logging.getLogger(__name__)


def verify_target_capabilities(
    target: Any,
    test_modalities: Optional[List[str]] = None,
) -> Dict[str, bool]:
    """
    Optional utility to verify actual modality capabilities of a target via runtime testing.

    This is a verification tool that can be used to check if declared capabilities
    match actual API behavior. It's completely optional and separate from the main
    modality support system.

    Args:
        target: The target to test (must have ``_async_client`` and ``model_name``).
        test_modalities: List of modalities to test (defaults to ``["image", "audio"]``).

    Returns:
        Dict[str, bool]: Mapping of modality name to actual support status. All
        values are False when the target cannot be probed at all.
    """
    if test_modalities is None:
        test_modalities = ["image", "audio"]

    # Without an async client and a model name there is nothing to probe.
    if not hasattr(target, "_async_client") or not hasattr(target, "model_name"):
        # Lazy %-style args so the message is only formatted if emitted.
        logger.warning("Target %s doesn't support capability testing", type(target).__name__)
        return {modality: False for modality in test_modalities}

    capabilities: Dict[str, bool] = {}
    for modality in test_modalities:
        try:
            capabilities[modality] = _test_single_modality(target, modality)
        except Exception as e:
            # _test_single_modality is expected to swallow its own errors, but a
            # failure probing one modality must not invalidate the results that
            # were already collected for the other modalities.
            logger.warning(
                "Capability probe for %s (%s) failed: %s", target.model_name, modality, e
            )
            capabilities[modality] = False

    logger.info("Verified capabilities for %s: %s", target.model_name, capabilities)
    return capabilities


def _test_single_modality(target: Any, modality: str) -> bool:
"""Test a single modality with minimal API request."""
try:
if modality == "image":
# Minimal 1x1 PNG test
minimal_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
test_messages = [{
"role": "user",
"content": [
{"type": "text", "text": "Test image?"},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{minimal_png}"}}
]
}]
elif modality == "audio":
# Minimal audio test (placeholder)
test_messages = [{
"role": "user",
"content": [{"type": "text", "text": "Test audio?"}]
}]
else:
return False

# Try the request
async def _test():
try:
await target._async_client.chat.completions.create(
model=target.model_name,
messages=test_messages,
max_tokens=1
)
return True
except Exception as e:
error_msg = str(e).lower()
# Check for modality-specific errors
if any(indicator in error_msg for indicator in [
"does not support", "not supported", "vision is not supported", "invalid content type"
]):
return False
return False # Any error = not supported

# Run the async test
try:
loop = asyncio.get_running_loop()
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, _test())
return future.result(timeout=10)
except RuntimeError:
return asyncio.run(_test())

except Exception:
return False
3 changes: 3 additions & 0 deletions pyrit/identifiers/target_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class TargetIdentifier(Identifier):
target_specific_params: Optional[Dict[str, Any]] = None
"""Additional target-specific parameters."""

supports_conversation_history: bool = True
"""Whether the target supports maintaining conversation history."""

@classmethod
def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIdentifier":
"""
Expand Down
57 changes: 57 additions & 0 deletions pyrit/prompt_target/common/prompt_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyrit.identifiers import Identifiable, TargetIdentifier
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import Message
from pyrit.models.literals import PromptDataType

logger = logging.getLogger(__name__)

Expand All @@ -26,6 +27,16 @@ class PromptTarget(Identifiable[TargetIdentifier]):
#: An empty list implies that the prompt target supports all converters.
supported_converters: List[Any]

#: Set of supported input modality combinations. Each frozenset represents a valid
#: combination of modalities that can be sent together in a single request.
#: Example: {frozenset({"text"}), frozenset({"text", "image_path"})} means supports text-only OR text+image
SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset({"text"})}

#: Set of supported output modality combinations. Each frozenset represents a valid
#: combination of modalities that can be returned together in a single response.
#: Example: {frozenset({"text"})} means produces text-only outputs
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {frozenset({"text"})}

_identifier: Optional[TargetIdentifier] = None

def __init__(
Expand Down Expand Up @@ -152,3 +163,49 @@ def _build_identifier(self) -> TargetIdentifier:
TargetIdentifier: The identifier for this prompt target.
"""
return self._create_identifier()

def input_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Determine whether this target accepts the given input modalities together.

    Args:
        modalities: Modalities intended to be sent in one request,
            e.g. ``{"text", "image_path"}``.

    Returns:
        bool: True when this exact combination appears in
        ``SUPPORTED_INPUT_MODALITIES``, False otherwise.
    """
    # SUPPORTED_INPUT_MODALITIES may be a plain class attribute or a property
    # overridden by a subclass; plain attribute access handles both.
    return frozenset(modalities) in self.SUPPORTED_INPUT_MODALITIES

def output_modality_supported(self, modalities: set[PromptDataType]) -> bool:
    """
    Determine whether this target can return the given output modalities together.

    Args:
        modalities: Modalities expected together in one response,
            e.g. ``{"text", "image_url"}``.

    Returns:
        bool: True when this exact combination appears in
        ``SUPPORTED_OUTPUT_MODALITIES``, False otherwise.
    """
    requested = frozenset(modalities)
    for combination in self.SUPPORTED_OUTPUT_MODALITIES:
        if combination == requested:
            return True
    return False

@property
def supported_input_modalities(self) -> set[PromptDataType]:
    """
    All individual input modalities appearing in any supported combination.

    Returns:
        set[PromptDataType]: Union of every combination in
        ``SUPPORTED_INPUT_MODALITIES``; empty set when none are declared.
    """
    flattened: set[PromptDataType] = set()
    for combination in self.SUPPORTED_INPUT_MODALITIES:
        flattened.update(combination)
    return flattened

@property
def supported_output_modalities(self) -> set[PromptDataType]:
    """
    All individual output modalities appearing in any supported combination.

    Returns:
        set[PromptDataType]: Union of every combination in
        ``SUPPORTED_OUTPUT_MODALITIES``; empty set when none are declared.
    """
    # set().union(*combos) also yields an empty set for an empty collection.
    return set().union(*self.SUPPORTED_OUTPUT_MODALITIES)
13 changes: 13 additions & 0 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import asyncio
import logging
import os
Expand All @@ -19,6 +21,7 @@
from pyrit.exceptions import EmptyResponseException, pyrit_target_retry
from pyrit.identifiers import TargetIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.models.literals import PromptDataType
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

Expand All @@ -34,6 +37,16 @@ class HuggingFaceChatTarget(PromptChatTarget):
Inherits from PromptTarget to comply with the current design standards.
"""

#: HuggingFace Chat targets support only text input and output
SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {
frozenset({"text"})
}

#: HuggingFace Chat targets produce only text outputs
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {
frozenset({"text"})
}

# Class-level cache for model and tokenizer
_cached_model = None
_cached_tokenizer = None
Expand Down
95 changes: 95 additions & 0 deletions pyrit/prompt_target/openai/openai_chat_target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import asyncio
import base64
import concurrent.futures
import json
import logging
from typing import Any, Dict, MutableSequence, Optional
Expand All @@ -22,6 +26,7 @@
data_serializer_factory,
)
from pyrit.models.json_response_config import _JsonResponseConfig
from pyrit.models.literals import PromptDataType
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p
from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig
Expand Down Expand Up @@ -62,6 +67,96 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget):

"""

@property
def SUPPORTED_INPUT_MODALITIES(self) -> set[frozenset[PromptDataType]]:
    """
    Supported input modality combinations, inferred from the model name.

    Pattern-matches ``model_name`` against known OpenAI model families and
    falls back to text-only for unrecognized models. For explicit runtime
    verification, use ``verify_capabilities()``.
    """
    name = self.model_name.lower()

    text_only: set[frozenset[PromptDataType]] = {frozenset({"text"})}
    text_and_image: set[frozenset[PromptDataType]] = {
        frozenset({"text"}),
        frozenset({"text", "image_path"}),
    }

    # Legacy / text-only families take precedence over everything else.
    legacy_markers = ("gpt-3.5", "davinci", "curie", "babbage", "ada")
    if any(marker in name for marker in legacy_markers):
        return text_only

    # Names that explicitly advertise vision / multimodality.
    multimodal_markers = (
        "vision",       # gpt-4-vision-preview, etc.
        "gpt-4o",       # gpt-4o, gpt-4o-mini, gpt-4o-2024-*, etc.
        "gpt-4-turbo",  # often has vision
        "gpt-5",        # future models likely multimodal
        "gpt-4.5",      # hypothetical intermediate
        "multimodal",   # explicit in name
        "omni",         # omni-modal models
    )
    if any(marker in name for marker in multimodal_markers):
        return text_and_image

    # Remaining gpt-4 variants are assumed multimodal (safer for newer
    # models), except known older text-only snapshots.
    if "gpt-4" in name and "gpt-4-0314" not in name and "gpt-4-32k" not in name:
        return text_and_image

    # Conservative default for completely unknown models.
    return text_only

@property
def SUPPORTED_OUTPUT_MODALITIES(self) -> set[frozenset[PromptDataType]]:
    """Supported output combinations: OpenAI chat models currently emit text only."""
    text_only: frozenset[PromptDataType] = frozenset({"text"})
    return {text_only}

def verify_capabilities(self, use_as_fallback: bool = False) -> dict[str, bool]:
    """
    Optional verification of actual capabilities via runtime testing.

    Args:
        use_as_fallback: When True, caches the results on the instance
            (``_verified_modalities``) for reuse by
            ``get_verified_input_modalities``.

    Returns:
        dict[str, bool]: Mapping of modality name to actual support status.
    """
    # Imported locally to keep the optional discovery utility out of the
    # module import path.
    from pyrit.common.modality_discovery import verify_target_capabilities

    results = verify_target_capabilities(self)

    # Only the first verification populates the cache; later calls leave it as-is.
    if use_as_fallback and not hasattr(self, "_verified_modalities"):
        self._verified_modalities = results

    return results

def get_verified_input_modalities(self) -> set[frozenset[PromptDataType]]:
    """
    Input modality combinations derived from runtime verification.

    Uses the cached verification results when present, otherwise runs (and
    caches) a fresh verification. More accurate but slower than the static
    ``SUPPORTED_INPUT_MODALITIES`` heuristic.
    """
    if hasattr(self, "_verified_modalities"):
        verified = self._verified_modalities
    else:
        verified = self.verify_capabilities(use_as_fallback=True)

    # Text is always assumed; each verified capability adds a text+X combination.
    combinations: set[frozenset[PromptDataType]] = {frozenset({"text"})}
    capability_to_data_type = (
        ("image", "image_path"),
        ("audio", "audio_path"),
        ("video", "video_path"),
    )
    for capability, data_type in capability_to_data_type:
        if verified.get(capability, False):
            combinations.add(frozenset({"text", data_type}))

    return combinations

def __init__(
self,
*,
Expand Down
13 changes: 13 additions & 0 deletions pyrit/prompt_target/text_target.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import csv
import json
import sys
from pathlib import Path
from typing import IO

from pyrit.models import Message, MessagePiece
from pyrit.models.literals import PromptDataType
from pyrit.prompt_target.common.prompt_target import PromptTarget


Expand All @@ -20,6 +23,16 @@ class TextTarget(PromptTarget):
but enter them manually.
"""

#: TextTarget supports only text input and output
SUPPORTED_INPUT_MODALITIES: set[frozenset[PromptDataType]] = {
frozenset({"text"})
}

#: TextTarget produces only text outputs
SUPPORTED_OUTPUT_MODALITIES: set[frozenset[PromptDataType]] = {
frozenset({"text"})
}

def __init__(
self,
*,
Expand Down
Loading