From 99ee74985846521052b92dc9bacf0cdb0bdd0a6d Mon Sep 17 00:00:00 2001
From: Sydney Lister
Date: Fri, 30 Jan 2026 10:15:33 -0500
Subject: [PATCH] Fix AAD authentication in ACA environments for RedTeam

When running red team scans in Azure Container Apps (ACA) environments,
DefaultAzureCredential fails because the Azure CLI is not available. This
fix allows passing an explicit credential from the RedTeam instance down
to get_chat_target().

Changes:
- Add credential parameter to get_chat_target()
- Resolve the credential from the target dict first, falling back to the
  new parameter, and route it through the existing _create_token_provider()
  helper so OpenAIChatTarget receives a token callable as its api_key
- Pass self.credential from RedTeam.scan() to get_chat_target()
- Add tests for credential-based authentication paths

Auth priority: api_key > credential > use_aad_auth (DefaultAzureCredential)
---
 .../azure/ai/evaluation/red_team/_red_team.py |  73 +++++++--
 .../red_team/_utils/strategy_utils.py         |  43 ++++--
 .../test_redteam/test_strategy_utils.py       | 142 +++++++++++++++++-
 3 files changed, 227 insertions(+), 31 deletions(-)

diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py
index 704ee535024f..aff8b174d879 100644
--- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py
+++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py
@@ -24,7 +24,9 @@
 )  # TODO: uncomment when app insights checked in
 from azure.ai.evaluation._model_configurations import EvaluationResult
 from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager
-from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
+from azure.ai.evaluation.simulator._model_tools._generated_rai_client import (
+    GeneratedRAIClient,
+)
 from azure.ai.evaluation._user_agent import UserAgentSingleton
 from azure.ai.evaluation._model_configurations import (
     AzureOpenAIModelConfiguration,
@@ -59,7 +61,11 @@
 from pyrit.prompt_target import PromptChatTarget
 
 # Local imports - constants and utilities
-from ._utils.constants import TASK_STATUS, MAX_SAMPLING_ITERATIONS_MULTIPLIER, RISK_TO_NUM_SUBTYPE_MAP
+from ._utils.constants import (
+    TASK_STATUS,
+    MAX_SAMPLING_ITERATIONS_MULTIPLIER,
+    RISK_TO_NUM_SUBTYPE_MAP,
+)
 from ._utils.logging_utils import (
     setup_logger,
     log_section_header,
@@ -205,7 +211,8 @@ def __init__(
 
         # Initialize RAI client
         self.generated_rai_client = GeneratedRAIClient(
-            azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.credential
+            azure_ai_project=self.azure_ai_project,
+            token_manager=self.token_manager.credential,
         )
 
         # Initialize a cache for attack objectives by risk category and strategy
@@ -386,7 +393,12 @@ async def _get_attack_objectives(
         if custom_objectives:
             # Use custom objectives for this risk category
             return await self._get_custom_attack_objectives(
-                risk_cat_value, num_objectives, num_objectives_with_subtypes, strategy, current_key, is_agent_target
+                risk_cat_value,
+                num_objectives,
+                num_objectives_with_subtypes,
+                strategy,
+                current_key,
+                is_agent_target,
             )
         else:
             # No custom objectives for this risk category, but risk_categories was specified
@@ -574,7 +586,13 @@ async def _get_custom_attack_objectives(
                     self.prompt_to_risk_subtype[content] = risk_subtype
 
         # Store in cache and return
-        self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
+        self._cache_attack_objectives(
+            current_key,
+            risk_cat_value,
+            strategy,
+            selected_prompts,
+            selected_cat_objectives,
+        )
         return selected_prompts
 
     async def _get_rai_attack_objectives(
@@ -680,12 +698,22 @@ async def _get_rai_attack_objectives(
 
         # Filter and select objectives using num_objectives_with_subtypes
         selected_cat_objectives = self._filter_and_select_objectives(
-            objectives_response, strategy, baseline_objectives_exist, baseline_key, num_objectives_with_subtypes
+            objectives_response,
+            strategy,
+            baseline_objectives_exist,
+            baseline_key,
+            num_objectives_with_subtypes,
         )
 
         # Extract content and cache
         selected_prompts = self._extract_objective_content(selected_cat_objectives)
-        self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
+        self._cache_attack_objectives(
+            current_key,
+            risk_cat_value,
+            strategy,
+            selected_prompts,
+            selected_cat_objectives,
+        )
 
         return selected_prompts
 
@@ -820,7 +848,11 @@ async def get_xpia_prompts_with_retry():
 
         # Build the contexts list: XPIA context + any baseline contexts with agent fields
         contexts = [
-            {"content": formatted_context, "context_type": context_type, "tool_name": tool_name}
+            {
+                "content": formatted_context,
+                "context_type": context_type,
+                "tool_name": tool_name,
+            }
         ]
 
         # Add baseline contexts with agent fields as separate context entries
@@ -1362,10 +1394,13 @@ async def scan(
 
         # Fetch attack objectives
         all_objectives = await self._fetch_all_objectives(
-            flattened_attack_strategies, application_scenario, is_agent_target, client_id
+            flattened_attack_strategies,
+            application_scenario,
+            is_agent_target,
+            client_id,
         )
 
-        chat_target = get_chat_target(target)
+        chat_target = get_chat_target(target, credential=self.credential)
         self.chat_target = chat_target
 
         # Execute attacks
@@ -1481,7 +1516,10 @@ async def _fetch_all_objectives(
 
         # Calculate and log num_objectives_with_subtypes once globally
         num_objectives = self.attack_objective_generator.num_objectives
-        max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories), default=0)
+        max_num_subtypes = max(
+            (RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories),
+            default=0,
+        )
         num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)
 
         if num_objectives_with_subtypes != num_objectives:
@@ -1594,7 +1632,11 @@ async def _execute_attacks(
             progress_bar.close()
 
     async def _process_orchestrator_tasks(
-        self, orchestrator_tasks: List, parallel_execution: bool, max_parallel_tasks: int, timeout: int
+        self,
+        orchestrator_tasks: List,
+        parallel_execution: bool,
+        max_parallel_tasks: int,
+        timeout: int,
     ):
         """Process orchestrator tasks either in parallel or sequentially."""
         if parallel_execution and orchestrator_tasks:
@@ -1629,7 +1671,12 @@ async def _process_orchestrator_tasks(
                     continue
 
     async def _finalize_results(
-        self, skip_upload: bool, skip_evals: bool, eval_run, output_path: str, scan_name: str
+        self,
+        skip_upload: bool,
+        skip_evals: bool,
+        eval_run,
+        output_path: str,
+        scan_name: str,
     ) -> RedTeamResult:
         """Process and finalize scan results."""
         log_section_header(self.logger, "Processing results")
diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py
index 26455d3031f7..ce789b6d2770 100644
--- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py
+++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py
@@ -6,7 +6,9 @@
 from typing import Dict, List, Union, Optional, Any, Callable, cast
 import logging
 
-from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
+from azure.ai.evaluation.simulator._model_tools._generated_rai_client import (
+    GeneratedRAIClient,
+)
 from .._attack_strategy import AttackStrategy
 from pyrit.prompt_converter import (
     PromptConverter,
@@ -35,8 +37,10 @@
 from .._default_converter import _DefaultConverter
 from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
 from .._callback_chat_target import _CallbackChatTarget
-from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
-
+from azure.ai.evaluation._model_configurations import (
+    AzureOpenAIModelConfiguration,
+    OpenAIModelConfiguration,
+)
 
 # Azure OpenAI uses cognitive services scope for AAD authentication
 AZURE_OPENAI_SCOPE = "https://cognitiveservices.azure.com/.default"
@@ -65,7 +69,9 @@ def get_token() -> str:
 
 
 def create_tense_converter(
-    generated_rai_client: GeneratedRAIClient, is_one_dp_project: bool, logger: logging.Logger
+    generated_rai_client: GeneratedRAIClient,
+    is_one_dp_project: bool,
+    logger: logging.Logger,
 ) -> TenseConverter:
     """Factory function for creating TenseConverter with proper dependencies."""
     converter_target = AzureRAIServiceTarget(
@@ -141,12 +147,22 @@ def _resolve_converter(strategy):
 
 
 def get_chat_target(
-    target: Union[PromptChatTarget, Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
+    target: Union[
+        PromptChatTarget,
+        Callable,
+        AzureOpenAIModelConfiguration,
+        OpenAIModelConfiguration,
+    ],
+    credential: Optional[Any] = None,
 ) -> PromptChatTarget:
     """Convert various target types to a PromptChatTarget.
 
     :param target: The target to convert
     :type target: Union[PromptChatTarget, Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
+    :param credential: Optional credential object with get_token method for AAD authentication.
+        Used as a fallback when target doesn't have an api_key or credential field. This is useful
+        in ACA environments where DefaultAzureCredential is not available.
+    :type credential: Optional[Any]
     :return: A PromptChatTarget instance
     :rtype: PromptChatTarget
     """
@@ -166,9 +182,9 @@ def _message_to_dict(message):
     if not isinstance(target, Callable):
         if "azure_deployment" in target and "azure_endpoint" in target:  # Azure OpenAI
             api_key = target.get("api_key", None)
-            credential = target.get("credential", None)
             api_version = target.get("api_version", "2024-06-01")
-
+            # Check for credential in target dict or use passed credential parameter
+            target_credential = target.get("credential", None) or credential
             if api_key:
                 # Use API key authentication
                 chat_target = OpenAIChatTarget(
@@ -177,13 +193,13 @@ def _message_to_dict(message):
                     api_key=api_key,
                     api_version=api_version,
                 )
-            elif credential:
+            elif target_credential:
                 # Use explicit TokenCredential for AAD auth (e.g., in ACA environments)
-                token_provider = _create_token_provider(credential)
+                token_provider = _create_token_provider(target_credential)
                 chat_target = OpenAIChatTarget(
                     model_name=target["azure_deployment"],
                     endpoint=target["azure_endpoint"],
-                    api_key=token_provider,  # Callable that returns tokens
+                    api_key=token_provider,  # PyRIT accepts callable that returns token
                     api_version=api_version,
                 )
             else:
@@ -252,7 +268,12 @@ async def callback_target(
                 "context": {},
             }
             messages_list.append(formatted_response)  # type: ignore
-            return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}}
+            return {
+                "messages": messages_list,
+                "stream": stream,
+                "session_state": session_state,
+                "context": {},
+            }
 
         chat_target = _CallbackChatTarget(callback=callback_target)  # type: ignore
 
diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_redteam/test_strategy_utils.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_redteam/test_strategy_utils.py
index 41e6d75a454e..793e04c681ab 100644
--- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_redteam/test_strategy_utils.py
+++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_redteam/test_strategy_utils.py
@@ -15,7 +15,12 @@
 )
 from azure.ai.evaluation.red_team._attack_strategy import AttackStrategy
 from azure.ai.evaluation.red_team._callback_chat_target import _CallbackChatTarget
-from pyrit.prompt_converter import PromptConverter, Base64Converter, FlipConverter, MorseConverter
+from pyrit.prompt_converter import (
+    PromptConverter,
+    Base64Converter,
+    FlipConverter,
+    MorseConverter,
+)
 from pyrit.prompt_target import PromptChatTarget, OpenAIChatTarget
 
 initialize_pyrit(memory_db_type=IN_MEMORY)
@@ -119,7 +124,10 @@ def test_get_chat_target_azure_openai(self, mock_openai_chat_target):
         mock_openai_chat_target.reset_mock()
 
         # Test with AAD auth
-        config = {"azure_deployment": "gpt-35-turbo", "azure_endpoint": "https://example.openai.azure.com"}
+        config = {
+            "azure_deployment": "gpt-35-turbo",
+            "azure_endpoint": "https://example.openai.azure.com",
+        }
 
         result = get_chat_target(config)
 
@@ -131,8 +139,8 @@
         )
 
     @patch("azure.ai.evaluation.red_team._utils.strategy_utils.OpenAIChatTarget")
-    def test_get_chat_target_azure_openai_with_credential(self, mock_openai_chat_target):
-        """Test getting chat target from an Azure OpenAI configuration with TokenCredential."""
+    def test_get_chat_target_azure_openai_with_credential_in_target(self, mock_openai_chat_target):
+        """Test getting chat target from an Azure OpenAI configuration with credential in target dict."""
         mock_instance = MagicMock()
         mock_openai_chat_target.return_value = mock_instance
 
@@ -167,6 +175,116 @@ def test_get_chat_target_azure_openai_with_credential(self, mock_openai_chat_tar
 
         assert result == mock_instance
 
+    @patch("azure.ai.evaluation.red_team._utils.strategy_utils.OpenAIChatTarget")
+    def test_get_chat_target_azure_openai_with_credential_parameter(self, mock_openai_chat_target):
+        """Test getting chat target with credential passed as parameter (for ACA environments)."""
+        mock_instance = MagicMock()
+        mock_openai_chat_target.return_value = mock_instance
+
+        # Create a mock credential that behaves like TokenCredential
+        mock_credential = MagicMock()
+        mock_token = MagicMock()
+        mock_token.token = "test-access-token"
+        mock_credential.get_token.return_value = mock_token
+
+        # Config without api_key or credential
+        config = {
+            "azure_deployment": "gpt-35-turbo",
+            "azure_endpoint": "https://example.openai.azure.com",
+        }
+
+        # Pass credential as parameter (this is how RedTeam.scan() passes it)
+        result = get_chat_target(config, credential=mock_credential)
+
+        # Verify OpenAIChatTarget was called with a callable for api_key
+        mock_openai_chat_target.assert_called_once()
+        call_kwargs = mock_openai_chat_target.call_args[1]
+        assert call_kwargs["model_name"] == "gpt-35-turbo"
+        assert call_kwargs["endpoint"] == "https://example.openai.azure.com"
+        assert call_kwargs["api_version"] == "2024-06-01"
+        # api_key should be a callable (token provider)
+        assert callable(call_kwargs["api_key"])
+
+        # Verify the token provider returns the expected token
+        token_provider = call_kwargs["api_key"]
+        token = token_provider()
+        assert token == "test-access-token"
+        mock_credential.get_token.assert_called_with("https://cognitiveservices.azure.com/.default")
+
+        assert result == mock_instance
+
+    @patch("azure.ai.evaluation.red_team._utils.strategy_utils.OpenAIChatTarget")
+    def test_get_chat_target_azure_openai_api_key_takes_precedence(self, mock_openai_chat_target):
+        """Test that api_key takes precedence over credential when both are provided."""
+        mock_instance = MagicMock()
+        mock_openai_chat_target.return_value = mock_instance
+
+        mock_credential = MagicMock()
+
+        config = {
+            "azure_deployment": "gpt-35-turbo",
+            "azure_endpoint": "https://example.openai.azure.com",
+            "api_key": "test-api-key",
+            "credential": mock_credential,
+        }
+
+        result = get_chat_target(config)
+
+        # Should use api_key, not credential
+        mock_openai_chat_target.assert_called_once_with(
+            model_name="gpt-35-turbo",
+            endpoint="https://example.openai.azure.com",
+            api_key="test-api-key",
+            api_version="2024-06-01",
+        )
+        # Credential should not be used
+        mock_credential.get_token.assert_not_called()
+
+        assert result == mock_instance
+
+    @patch("azure.ai.evaluation.red_team._utils.strategy_utils.OpenAIChatTarget")
+    def test_get_chat_target_azure_openai_target_credential_takes_precedence_over_parameter(
+        self, mock_openai_chat_target
+    ):
+        """Test that target['credential'] takes precedence over credential parameter."""
+        mock_instance = MagicMock()
+        mock_openai_chat_target.return_value = mock_instance
+
+        # Create two different mock credentials
+        target_credential = MagicMock()
+        target_token = MagicMock()
+        target_token.token = "target-credential-token"
+        target_credential.get_token.return_value = target_token
+
+        param_credential = MagicMock()
+        param_token = MagicMock()
+        param_token.token = "param-credential-token"
+        param_credential.get_token.return_value = param_token
+
+        # Config with credential in target dict
+        config = {
+            "azure_deployment": "gpt-35-turbo",
+            "azure_endpoint": "https://example.openai.azure.com",
+            "credential": target_credential,
+        }
+
+        # Pass different credential as parameter
+        result = get_chat_target(config, credential=param_credential)
+
+        # Verify OpenAIChatTarget was called
+        mock_openai_chat_target.assert_called_once()
+        call_kwargs = mock_openai_chat_target.call_args[1]
+
+        # Verify the token provider uses target_credential, not param_credential
+        token_provider = call_kwargs["api_key"]
+        token = token_provider()
+        assert token == "target-credential-token"
+        target_credential.get_token.assert_called_with("https://cognitiveservices.azure.com/.default")
+        # param_credential should not be used
+        param_credential.get_token.assert_not_called()
+
+        assert result == mock_instance
+
     @patch("azure.ai.evaluation.red_team._utils.strategy_utils.OpenAIChatTarget")
     def test_get_chat_target_openai(self, mock_openai_chat_target):
         """Test getting chat target from an OpenAI configuration."""
@@ -178,18 +296,28 @@ def test_get_chat_target_openai(self, mock_openai_chat_target):
         result = get_chat_target(config)
 
         mock_openai_chat_target.assert_called_once_with(
-            model_name="gpt-4", endpoint=None, api_key="test-api-key", api_version="2024-06-01"
+            model_name="gpt-4",
+            endpoint=None,
+            api_key="test-api-key",
+            api_version="2024-06-01",
         )
 
         # Test with base_url
         mock_openai_chat_target.reset_mock()
-        config = {"model": "gpt-4", "api_key": "test-api-key", "base_url": "https://example.com/api"}
+        config = {
+            "model": "gpt-4",
+            "api_key": "test-api-key",
+            "base_url": "https://example.com/api",
+        }
 
         result = get_chat_target(config)
 
         mock_openai_chat_target.assert_called_once_with(
-            model_name="gpt-4", endpoint="https://example.com/api", api_key="test-api-key", api_version="2024-06-01"
+            model_name="gpt-4",
+            endpoint="https://example.com/api",
+            api_key="test-api-key",
+            api_version="2024-06-01",
         )
 
     @patch("azure.ai.evaluation.red_team._utils.strategy_utils._CallbackChatTarget")