Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 91 additions & 25 deletions pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import enum
import logging
import uuid
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional

from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
from pyrit.common.path import EXECUTOR_RED_TEAM_PATH
Expand All @@ -33,6 +34,7 @@
ConversationReference,
ConversationType,
Message,
MessagePiece,
Score,
SeedPrompt,
)
Expand Down Expand Up @@ -355,12 +357,21 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any]
# Generate prompt using adversarial chat
logger.debug(f"Generating prompt for turn {context.executed_turns + 1}")

# Prepare prompt for the adversarial chat
prompt_text = await self._build_adversarial_prompt(context)

# Send the prompt to the adversarial chat and get the response
logger.debug(f"Sending prompt to adversarial chat: {prompt_text[:50]}...")
prompt_message = Message.from_prompt(prompt=prompt_text, role="user")
# Build the message for the adversarial chat
prompt_message = await self._build_adversarial_prompt(context)

# Log the message being sent
if prompt_message.is_multimodal():
text_piece = prompt_message.get_first_piece_by_data_type("text")
media_pieces = [p for p in prompt_message.message_pieces if p.converted_value_data_type != "text"]
feedback_text = text_piece.converted_value if text_piece else "No text content"
media_info = f"{len(media_pieces)} media piece(s)" if media_pieces else "no media"
logger.debug(
f"Sending multimodal prompt to adversarial chat: {feedback_text[:50]}... + {media_info}"
)
else:
prompt_text = prompt_message.get_first_piece().converted_value
logger.debug(f"Sending prompt to adversarial chat: {prompt_text[:50]}...")

with execution_context(
component_role=ComponentRole.ADVERSARIAL_CHAT,
Expand Down Expand Up @@ -388,32 +399,35 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any]
async def _build_adversarial_prompt(
self,
context: MultiTurnAttackContext[Any],
) -> str:
) -> Message:
"""
Build a prompt for the adversarial chat based on the last response.
Build a prompt message for the adversarial chat based on the last response.

For text responses, creates a simple text message. For file/media responses (images, video, etc.),
creates a multimodal message that includes both the textual feedback and the actual generated
media so the adversarial chat can see what the target produced.

Args:
context (MultiTurnAttackContext): The attack context containing the current state and configuration.

Returns:
str: The prompt to be sent to the adversarial chat.
Message: A message ready to be sent to the adversarial chat.
"""
# If no last response, return the seed prompt (rendered with objective if template exists)
if not context.last_response:
return self._adversarial_chat_seed_prompt.render_template_value_silent(objective=context.objective)
prompt_text = self._adversarial_chat_seed_prompt.render_template_value_silent(objective=context.objective)
return Message.from_prompt(prompt=prompt_text, role="user")

# Get the last assistant piece from the response
response_piece = context.last_response.get_piece()

# Delegate to appropriate handler based on data type
handlers = {
"text": self._handle_adversarial_text_response,
"error": self._handle_adversarial_text_response,
}

handler = handlers.get(response_piece.converted_value_data_type, self._handle_adversarial_file_response)

return handler(context=context)
# Build message based on response type (text vs file/media)
if response_piece.converted_value_data_type in ("text", "error"):
feedback_text = self._handle_adversarial_text_response(context=context)
return self._build_text_message(feedback_text)
else:
feedback_text, media_piece = self._handle_adversarial_file_response(context=context)
return self._build_multimodal_message(feedback_text, media_piece)

def _handle_adversarial_text_response(self, *, context: MultiTurnAttackContext[Any]) -> str:
"""
Expand Down Expand Up @@ -450,25 +464,34 @@ def _handle_adversarial_text_response(self, *, context: MultiTurnAttackContext[A

return f"Request to target failed: {response_piece.response_error}"

def _handle_adversarial_file_response(self, *, context: MultiTurnAttackContext[Any]) -> str:
def _handle_adversarial_file_response(
self, *, context: MultiTurnAttackContext[Any]
) -> tuple[str, Optional[MessagePiece]]:
"""
Handle the file response from the target.

Returns the scoring feedback text along with the media piece from the target's response,
enabling the adversarial chat to receive a multimodal message with both the textual feedback
and the actual generated media (image, video, etc.) for more informed prompt generation.

If the response indicates an error, raise a RuntimeError. When scoring is disabled or no
scoring rationale is provided, raise a ValueError. Otherwise, return the textual feedback as the prompt.
scoring rationale is provided, raise a ValueError. Otherwise, return the textual feedback
and the media piece as a tuple.

Args:
context (MultiTurnAttackContext): The attack context containing the response and score.

Returns:
str: The suitable feedback or error message to pass back to the adversarial chat.
tuple[str, Optional[MessagePiece]]: A tuple of (feedback_text, media_piece).
The media_piece is the response piece from the target containing the generated media,
or None if no response is available.

Raises:
RuntimeError: If the target response indicates an error.
ValueError: If scoring is disabled or no scoring rationale is available.
"""
if not context.last_response:
return "No response available. Please continue."
return ("No response available. Please continue.", None)

response_piece = context.last_response.get_piece()

Expand All @@ -494,7 +517,50 @@ def _handle_adversarial_file_response(self, *, context: MultiTurnAttackContext[A
"However, no scoring rationale was provided by the scorer."
)

return feedback
return (feedback, response_piece)

def _build_text_message(self, feedback_text: str) -> Message:
"""
Build a simple text message for the adversarial chat.

Args:
feedback_text (str): The text content for the message.

Returns:
Message: A text message ready to be sent to the adversarial chat.
"""
return Message.from_prompt(prompt=feedback_text, role="user")

def _build_multimodal_message(self, feedback_text: str, media_piece: Optional[MessagePiece]) -> Message:
"""
Build a multimodal message for the adversarial chat containing both text and media.

Args:
feedback_text (str): The textual feedback to include.
media_piece (Optional[MessagePiece]): The media piece from the target response, if any.

Returns:
Message: A multimodal message ready to be sent to the adversarial chat.
"""
# Use a shared conversation_id so Message validation passes
shared_conversation_id = str(uuid.uuid4())
pieces = [
MessagePiece(
original_value=feedback_text,
role="user",
conversation_id=shared_conversation_id,
)
]
if media_piece is not None:
pieces.append(
MessagePiece(
original_value=media_piece.converted_value,
role="user",
original_value_data_type=media_piece.converted_value_data_type,
conversation_id=shared_conversation_id,
)
)
return Message(message_pieces=pieces)

async def _send_prompt_to_objective_target_async(
self,
Expand Down
2 changes: 1 addition & 1 deletion pyrit/prompt_target/openai/openai_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ async def _build_chat_messages_for_multi_modal_async(
):
continue

if message_piece.converted_value_data_type == "text":
if message_piece.converted_value_data_type in ("text", "error"):
entry = {"type": "text", "text": message_piece.converted_value}
content.append(entry)
elif message_piece.converted_value_data_type == "image_path":
Expand Down
Loading