diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 33b2c75d7..cf21a0d76 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -5,8 +5,9 @@ import enum import logging +import uuid from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH @@ -33,6 +34,7 @@ ConversationReference, ConversationType, Message, + MessagePiece, Score, SeedPrompt, ) @@ -355,12 +357,21 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] # Generate prompt using adversarial chat logger.debug(f"Generating prompt for turn {context.executed_turns + 1}") - # Prepare prompt for the adversarial chat - prompt_text = await self._build_adversarial_prompt(context) - - # Send the prompt to the adversarial chat and get the response - logger.debug(f"Sending prompt to adversarial chat: {prompt_text[:50]}...") - prompt_message = Message.from_prompt(prompt=prompt_text, role="user") + # Build the message for the adversarial chat + prompt_message = await self._build_adversarial_prompt(context) + + # Log the message being sent + if prompt_message.is_multimodal(): + text_piece = prompt_message.get_first_piece_by_data_type("text") + media_pieces = [p for p in prompt_message.message_pieces if p.converted_value_data_type != "text"] + feedback_text = text_piece.converted_value if text_piece else "No text content" + media_info = f"{len(media_pieces)} media piece(s)" if media_pieces else "no media" + logger.debug( + f"Sending multimodal prompt to adversarial chat: {feedback_text[:50]}... "
+            f"{media_info}" + ) + else: + prompt_text = prompt_message.get_first_piece().converted_value + logger.debug(f"Sending prompt to adversarial chat: {prompt_text[:50]}...") with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, @@ -388,32 +399,35 @@ async def _build_adversarial_prompt( self, context: MultiTurnAttackContext[Any], - ) -> str: + ) -> Message: """ - Build a prompt for the adversarial chat based on the last response. + Build a prompt message for the adversarial chat based on the last response. + + For text responses, creates a simple text message. For file/media responses (images, video, etc.), + creates a multimodal message that includes both the textual feedback and the actual generated + media so the adversarial chat can see what the target produced. Args: context (MultiTurnAttackContext): The attack context containing the current state and configuration. Returns: - str: The prompt to be sent to the adversarial chat. + Message: A message ready to be sent to the adversarial chat. 
""" # If no last response, return the seed prompt (rendered with objective if template exists) if not context.last_response: - return self._adversarial_chat_seed_prompt.render_template_value_silent(objective=context.objective) + prompt_text = self._adversarial_chat_seed_prompt.render_template_value_silent(objective=context.objective) + return Message.from_prompt(prompt=prompt_text, role="user") # Get the last assistant piece from the response response_piece = context.last_response.get_piece() - # Delegate to appropriate handler based on data type - handlers = { - "text": self._handle_adversarial_text_response, - "error": self._handle_adversarial_text_response, - } - - handler = handlers.get(response_piece.converted_value_data_type, self._handle_adversarial_file_response) - - return handler(context=context) + # Build message based on response type (text vs file/media) + if response_piece.converted_value_data_type in ("text", "error"): + feedback_text = self._handle_adversarial_text_response(context=context) + return self._build_text_message(feedback_text) + else: + feedback_text, media_piece = self._handle_adversarial_file_response(context=context) + return self._build_multimodal_message(feedback_text, media_piece) def _handle_adversarial_text_response(self, *, context: MultiTurnAttackContext[Any]) -> str: """ @@ -450,25 +464,34 @@ def _handle_adversarial_text_response(self, *, context: MultiTurnAttackContext[A return f"Request to target failed: {response_piece.response_error}" - def _handle_adversarial_file_response(self, *, context: MultiTurnAttackContext[Any]) -> str: + def _handle_adversarial_file_response( + self, *, context: MultiTurnAttackContext[Any] + ) -> tuple[str, Optional[MessagePiece]]: """ Handle the file response from the target. 
+ Returns the scoring feedback text along with the media piece from the target's response, + enabling the adversarial chat to receive a multimodal message with both the textual feedback + and the actual generated media (image, video, etc.) for more informed prompt generation. + If the response indicates an error, raise a RuntimeError. When scoring is disabled or no - scoring rationale is provided, raise a ValueError. Otherwise, return the textual feedback as the prompt. + scoring rationale is provided, raise a ValueError. Otherwise, return the textual feedback + and the media piece as a tuple. Args: context (MultiTurnAttackContext): The attack context containing the response and score. Returns: - str: The suitable feedback or error message to pass back to the adversarial chat. + tuple[str, Optional[MessagePiece]]: A tuple of (feedback_text, media_piece). + The media_piece is the response piece from the target containing the generated media, + or None if no response is available. Raises: RuntimeError: If the target response indicates an error. ValueError: If scoring is disabled or no scoring rationale is available. """ if not context.last_response: - return "No response available. Please continue." + return ("No response available. Please continue.", None) response_piece = context.last_response.get_piece() @@ -494,7 +517,50 @@ def _handle_adversarial_file_response(self, *, context: MultiTurnAttackContext[A "However, no scoring rationale was provided by the scorer." ) - return feedback + return (feedback, response_piece) + + def _build_text_message(self, feedback_text: str) -> Message: + """ + Build a simple text message for the adversarial chat. + + Args: + feedback_text (str): The text content for the message. + + Returns: + Message: A text message ready to be sent to the adversarial chat. 
+ """ + return Message.from_prompt(prompt=feedback_text, role="user") + + def _build_multimodal_message(self, feedback_text: str, media_piece: Optional[MessagePiece]) -> Message: + """ + Build a multimodal message for the adversarial chat containing both text and media. + + Args: + feedback_text (str): The textual feedback to include. + media_piece (Optional[MessagePiece]): The media piece from the target response, if any. + + Returns: + Message: A multimodal message ready to be sent to the adversarial chat. + """ + # Use a shared conversation_id so Message validation passes + shared_conversation_id = str(uuid.uuid4()) + pieces = [ + MessagePiece( + original_value=feedback_text, + role="user", + conversation_id=shared_conversation_id, + ) + ] + if media_piece is not None: + pieces.append( + MessagePiece( + original_value=media_piece.converted_value, + role="user", + original_value_data_type=media_piece.converted_value_data_type, + conversation_id=shared_conversation_id, + ) + ) + return Message(message_pieces=pieces) async def _send_prompt_to_objective_target_async( self, diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 87ffa26f4..e00515ed4 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -587,7 +587,7 @@ async def _build_chat_messages_for_multi_modal_async( ): continue - if message_piece.converted_value_data_type == "text": + if message_piece.converted_value_data_type in ("text", "error"): entry = {"type": "text", "text": message_piece.converted_value} content.append(entry) elif message_piece.converted_value_data_type == "image_path": diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index e6ed898b4..f64127c9a 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ 
-1078,7 +1078,7 @@ def test_handle_adversarial_file_response_with_feedback( basic_context: MultiTurnAttackContext, success_score: Score, ): - """Test that file response with feedback returns score rationale.""" + """Test that file response with feedback returns score rationale and media piece.""" adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer, use_score_as_feedback=True) @@ -1098,7 +1098,10 @@ def test_handle_adversarial_file_response_with_feedback( result = attack._handle_adversarial_file_response(context=basic_context) - assert result == success_score.score_rationale + assert isinstance(result, tuple) + feedback_text, media = result + assert feedback_text == success_score.score_rationale + assert media is response_piece def test_handle_adversarial_file_response_no_response( self, @@ -1121,7 +1124,227 @@ def test_handle_adversarial_file_response_no_response( result = attack._handle_adversarial_file_response(context=basic_context) - assert result == "No response available. Please continue." + assert isinstance(result, tuple) + feedback_text, media = result + assert feedback_text == "No response available. Please continue." 
+ assert media is None + + +@pytest.mark.usefixtures("patch_central_database") +class TestMultimodalFeedbackLoop: + """Tests for multimodal media content flowing through the adversarial feedback loop.""" + + @pytest.mark.asyncio + async def test_generate_next_prompt_sends_multimodal_message_for_image_response( + self, + mock_objective_target: MagicMock, + mock_objective_scorer: MagicMock, + mock_adversarial_chat: MagicMock, + basic_context: MultiTurnAttackContext, + success_score: Score, + ): + """Test that when the target returns an image, the adversarial chat receives + a multimodal message containing both the text feedback and the image.""" + adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) + scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer, use_score_as_feedback=True) + + attack = RedTeamingAttack( + objective_target=mock_objective_target, + attack_adversarial_config=adversarial_config, + attack_scoring_config=scoring_config, + ) + + # Simulate a target response with an image + response_piece = MagicMock(spec=MessagePiece) + response_piece.converted_value_data_type = "image_path" + response_piece.converted_value = "/path/to/generated_image.png" + response_piece.has_error.return_value = False + + basic_context.last_response = MagicMock(spec=Message) + basic_context.last_response.get_piece.return_value = response_piece + basic_context.last_score = success_score + basic_context.executed_turns = 1 # Not the first turn + + # Mock the adversarial chat response + adversarial_response = MagicMock(spec=Message) + adversarial_response.get_value.return_value = "Generate a more explicit image" + + mock_normalizer = AsyncMock(spec=PromptNormalizer) + mock_normalizer.send_prompt_async = AsyncMock(return_value=adversarial_response) + attack._prompt_normalizer = mock_normalizer + + result = await attack._generate_next_prompt_async(context=basic_context) + + # Verify the message sent to adversarial chat was multimodal + 
call_args = mock_normalizer.send_prompt_async.call_args + sent_message = call_args.kwargs.get("message") or call_args[1].get("message") + assert len(sent_message.message_pieces) == 2 + assert sent_message.message_pieces[0].original_value == success_score.score_rationale + assert sent_message.message_pieces[0].original_value_data_type == "text" + assert sent_message.message_pieces[1].original_value == "/path/to/generated_image.png" + assert sent_message.message_pieces[1].original_value_data_type == "image_path" + + # Verify the returned message for the objective target is text-only + assert result.get_value() == "Generate a more explicit image" + + @pytest.mark.asyncio + async def test_generate_next_prompt_sends_multimodal_message_for_video_response( + self, + mock_objective_target: MagicMock, + mock_objective_scorer: MagicMock, + mock_adversarial_chat: MagicMock, + basic_context: MultiTurnAttackContext, + success_score: Score, + ): + """Test that when the target returns a video, the adversarial chat receives + a multimodal message containing both the text feedback and the video.""" + adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) + scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer, use_score_as_feedback=True) + + attack = RedTeamingAttack( + objective_target=mock_objective_target, + attack_adversarial_config=adversarial_config, + attack_scoring_config=scoring_config, + ) + + # Simulate a target response with a video + response_piece = MagicMock(spec=MessagePiece) + response_piece.converted_value_data_type = "video_path" + response_piece.converted_value = "/path/to/generated_video.mp4" + response_piece.has_error.return_value = False + + basic_context.last_response = MagicMock(spec=Message) + basic_context.last_response.get_piece.return_value = response_piece + basic_context.last_score = success_score + basic_context.executed_turns = 1 + + adversarial_response = MagicMock(spec=Message) + 
adversarial_response.get_value.return_value = "Try again with different content" + + mock_normalizer = AsyncMock(spec=PromptNormalizer) + mock_normalizer.send_prompt_async = AsyncMock(return_value=adversarial_response) + attack._prompt_normalizer = mock_normalizer + + result = await attack._generate_next_prompt_async(context=basic_context) + + # Verify multimodal message with video + call_args = mock_normalizer.send_prompt_async.call_args + sent_message = call_args.kwargs.get("message") or call_args[1].get("message") + assert len(sent_message.message_pieces) == 2 + assert sent_message.message_pieces[1].original_value_data_type == "video_path" + assert sent_message.message_pieces[1].original_value == "/path/to/generated_video.mp4" + + @pytest.mark.asyncio + async def test_generate_next_prompt_text_response_stays_text_only( + self, + mock_objective_target: MagicMock, + mock_objective_scorer: MagicMock, + mock_adversarial_chat: MagicMock, + basic_context: MultiTurnAttackContext, + success_score: Score, + ): + """Test that text responses still produce text-only messages (no regression).""" + adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) + scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer, use_score_as_feedback=True) + + attack = RedTeamingAttack( + objective_target=mock_objective_target, + attack_adversarial_config=adversarial_config, + attack_scoring_config=scoring_config, + ) + + # Simulate a text response + response_piece = MagicMock(spec=MessagePiece) + response_piece.converted_value_data_type = "text" + response_piece.converted_value = "I cannot help with that" + response_piece.has_error.return_value = False + + basic_context.last_response = MagicMock(spec=Message) + basic_context.last_response.get_piece.return_value = response_piece + basic_context.last_score = success_score + basic_context.executed_turns = 1 + + adversarial_response = MagicMock(spec=Message) + adversarial_response.get_value.return_value = 
"Try rephrasing" + + mock_normalizer = AsyncMock(spec=PromptNormalizer) + mock_normalizer.send_prompt_async = AsyncMock(return_value=adversarial_response) + attack._prompt_normalizer = mock_normalizer + + await attack._generate_next_prompt_async(context=basic_context) + + # Verify message sent to adversarial chat is text-only (single piece) + call_args = mock_normalizer.send_prompt_async.call_args + sent_message = call_args.kwargs.get("message") or call_args[1].get("message") + assert len(sent_message.message_pieces) == 1 + assert sent_message.message_pieces[0].original_value_data_type == "text" + + @pytest.mark.asyncio + async def test_build_adversarial_prompt_returns_multimodal_message_for_image_response( + self, + mock_objective_target: MagicMock, + mock_objective_scorer: MagicMock, + mock_adversarial_chat: MagicMock, + basic_context: MultiTurnAttackContext, + success_score: Score, + ): + """Test that _build_adversarial_prompt returns a multimodal Message for image responses.""" + adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) + scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer, use_score_as_feedback=True) + + attack = RedTeamingAttack( + objective_target=mock_objective_target, + attack_adversarial_config=adversarial_config, + attack_scoring_config=scoring_config, + ) + + response_piece = MagicMock(spec=MessagePiece) + response_piece.converted_value_data_type = "image_path" + response_piece.converted_value = "/path/to/image.png" + response_piece.has_error.return_value = False + + basic_context.last_response = MagicMock(spec=Message) + basic_context.last_response.get_piece.return_value = response_piece + basic_context.last_score = success_score + + result = await attack._build_adversarial_prompt(basic_context) + + assert isinstance(result, Message) + assert len(result.message_pieces) == 2 + assert result.message_pieces[0].original_value == success_score.score_rationale + assert result.message_pieces[1].original_value == "/path/to/image.png" + + @pytest.mark.asyncio + async def 
test_build_adversarial_prompt_returns_text_message_for_text_response( + self, + mock_objective_target: MagicMock, + mock_objective_scorer: MagicMock, + mock_adversarial_chat: MagicMock, + basic_context: MultiTurnAttackContext, + ): + """Test that _build_adversarial_prompt returns a text-only Message for text responses.""" + adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) + scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer) + + attack = RedTeamingAttack( + objective_target=mock_objective_target, + attack_adversarial_config=adversarial_config, + attack_scoring_config=scoring_config, + ) + + response_piece = MagicMock(spec=MessagePiece) + response_piece.converted_value_data_type = "text" + response_piece.converted_value = "Hello world" + response_piece.has_error.return_value = False + + basic_context.last_response = MagicMock(spec=Message) + basic_context.last_response.get_piece.return_value = response_piece + basic_context.last_score = None + + result = await attack._build_adversarial_prompt(basic_context) + + assert isinstance(result, Message) + assert result.get_first_piece().converted_value == "Hello world" @pytest.mark.usefixtures("patch_central_database")