Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
) # TODO: uncomment when app insights checked in
from azure.ai.evaluation._model_configurations import EvaluationResult
from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager
from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
from azure.ai.evaluation.simulator._model_tools._generated_rai_client import (
GeneratedRAIClient,
)
from azure.ai.evaluation._user_agent import UserAgentSingleton
from azure.ai.evaluation._model_configurations import (
AzureOpenAIModelConfiguration,
Expand Down Expand Up @@ -59,7 +61,11 @@
from pyrit.prompt_target import PromptChatTarget

# Local imports - constants and utilities
from ._utils.constants import TASK_STATUS, MAX_SAMPLING_ITERATIONS_MULTIPLIER, RISK_TO_NUM_SUBTYPE_MAP
from ._utils.constants import (
TASK_STATUS,
MAX_SAMPLING_ITERATIONS_MULTIPLIER,
RISK_TO_NUM_SUBTYPE_MAP,
)
from ._utils.logging_utils import (
setup_logger,
log_section_header,
Expand Down Expand Up @@ -205,7 +211,8 @@ def __init__(

# Initialize RAI client
self.generated_rai_client = GeneratedRAIClient(
azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.credential
azure_ai_project=self.azure_ai_project,
token_manager=self.token_manager.credential,
)

# Initialize a cache for attack objectives by risk category and strategy
Expand Down Expand Up @@ -386,7 +393,12 @@ async def _get_attack_objectives(
if custom_objectives:
# Use custom objectives for this risk category
return await self._get_custom_attack_objectives(
risk_cat_value, num_objectives, num_objectives_with_subtypes, strategy, current_key, is_agent_target
risk_cat_value,
num_objectives,
num_objectives_with_subtypes,
strategy,
current_key,
is_agent_target,
)
else:
# No custom objectives for this risk category, but risk_categories was specified
Expand Down Expand Up @@ -574,7 +586,13 @@ async def _get_custom_attack_objectives(
self.prompt_to_risk_subtype[content] = risk_subtype

# Store in cache and return
self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
self._cache_attack_objectives(
current_key,
risk_cat_value,
strategy,
selected_prompts,
selected_cat_objectives,
)
return selected_prompts

async def _get_rai_attack_objectives(
Expand Down Expand Up @@ -680,12 +698,22 @@ async def _get_rai_attack_objectives(

# Filter and select objectives using num_objectives_with_subtypes
selected_cat_objectives = self._filter_and_select_objectives(
objectives_response, strategy, baseline_objectives_exist, baseline_key, num_objectives_with_subtypes
objectives_response,
strategy,
baseline_objectives_exist,
baseline_key,
num_objectives_with_subtypes,
)

# Extract content and cache
selected_prompts = self._extract_objective_content(selected_cat_objectives)
self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
self._cache_attack_objectives(
current_key,
risk_cat_value,
strategy,
selected_prompts,
selected_cat_objectives,
)

return selected_prompts

Expand Down Expand Up @@ -820,7 +848,11 @@ async def get_xpia_prompts_with_retry():

# Build the contexts list: XPIA context + any baseline contexts with agent fields
contexts = [
{"content": formatted_context, "context_type": context_type, "tool_name": tool_name}
{
"content": formatted_context,
"context_type": context_type,
"tool_name": tool_name,
}
]

# Add baseline contexts with agent fields as separate context entries
Expand Down Expand Up @@ -1362,10 +1394,13 @@ async def scan(

# Fetch attack objectives
all_objectives = await self._fetch_all_objectives(
flattened_attack_strategies, application_scenario, is_agent_target, client_id
flattened_attack_strategies,
application_scenario,
is_agent_target,
client_id,
)

chat_target = get_chat_target(target)
chat_target = get_chat_target(target, credential=self.credential)
self.chat_target = chat_target

# Execute attacks
Expand Down Expand Up @@ -1481,7 +1516,10 @@ async def _fetch_all_objectives(

# Calculate and log num_objectives_with_subtypes once globally
num_objectives = self.attack_objective_generator.num_objectives
max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories), default=0)
max_num_subtypes = max(
(RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories),
default=0,
)
num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)

if num_objectives_with_subtypes != num_objectives:
Expand Down Expand Up @@ -1594,7 +1632,11 @@ async def _execute_attacks(
progress_bar.close()

async def _process_orchestrator_tasks(
self, orchestrator_tasks: List, parallel_execution: bool, max_parallel_tasks: int, timeout: int
self,
orchestrator_tasks: List,
parallel_execution: bool,
max_parallel_tasks: int,
timeout: int,
):
"""Process orchestrator tasks either in parallel or sequentially."""
if parallel_execution and orchestrator_tasks:
Expand Down Expand Up @@ -1629,7 +1671,12 @@ async def _process_orchestrator_tasks(
continue

async def _finalize_results(
self, skip_upload: bool, skip_evals: bool, eval_run, output_path: str, scan_name: str
self,
skip_upload: bool,
skip_evals: bool,
eval_run,
output_path: str,
scan_name: str,
) -> RedTeamResult:
"""Process and finalize scan results."""
log_section_header(self.logger, "Processing results")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from typing import Dict, List, Union, Optional, Any, Callable, cast
import logging

from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
from azure.ai.evaluation.simulator._model_tools._generated_rai_client import (
GeneratedRAIClient,
)
from .._attack_strategy import AttackStrategy
from pyrit.prompt_converter import (
PromptConverter,
Expand Down Expand Up @@ -35,8 +37,10 @@
from .._default_converter import _DefaultConverter
from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
from .._callback_chat_target import _CallbackChatTarget
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration

from azure.ai.evaluation._model_configurations import (
AzureOpenAIModelConfiguration,
OpenAIModelConfiguration,
)

# Azure OpenAI uses cognitive services scope for AAD authentication
AZURE_OPENAI_SCOPE = "https://cognitiveservices.azure.com/.default"
Expand Down Expand Up @@ -65,7 +69,9 @@ def get_token() -> str:


def create_tense_converter(
generated_rai_client: GeneratedRAIClient, is_one_dp_project: bool, logger: logging.Logger
generated_rai_client: GeneratedRAIClient,
is_one_dp_project: bool,
logger: logging.Logger,
) -> TenseConverter:
"""Factory function for creating TenseConverter with proper dependencies."""
converter_target = AzureRAIServiceTarget(
Expand Down Expand Up @@ -141,12 +147,22 @@ def _resolve_converter(strategy):


def get_chat_target(
target: Union[PromptChatTarget, Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
target: Union[
PromptChatTarget,
Callable,
AzureOpenAIModelConfiguration,
OpenAIModelConfiguration,
],
credential: Optional[Any] = None,
) -> PromptChatTarget:
"""Convert various target types to a PromptChatTarget.

:param target: The target to convert
:type target: Union[PromptChatTarget, Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
:param credential: Optional credential object with get_token method for AAD authentication.
Used as a fallback when target doesn't have an api_key or credential field. This is useful
in ACA environments where DefaultAzureCredential is not available.
:type credential: Optional[Any]
:return: A PromptChatTarget instance
:rtype: PromptChatTarget
"""
Expand All @@ -166,9 +182,9 @@ def _message_to_dict(message):
if not isinstance(target, Callable):
if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI
api_key = target.get("api_key", None)
credential = target.get("credential", None)
api_version = target.get("api_version", "2024-06-01")

# Check for credential in target dict or use passed credential parameter
target_credential = target.get("credential", None) or credential
if api_key:
# Use API key authentication
chat_target = OpenAIChatTarget(
Expand All @@ -177,13 +193,13 @@ def _message_to_dict(message):
api_key=api_key,
api_version=api_version,
)
elif credential:
elif target_credential:
# Use explicit TokenCredential for AAD auth (e.g., in ACA environments)
token_provider = _create_token_provider(credential)
token_provider = _create_token_provider(target_credential)
chat_target = OpenAIChatTarget(
model_name=target["azure_deployment"],
endpoint=target["azure_endpoint"],
api_key=token_provider, # Callable that returns tokens
api_key=token_provider, # PyRIT accepts callable that returns token
api_version=api_version,
)
else:
Expand Down Expand Up @@ -252,7 +268,12 @@ async def callback_target(
"context": {},
}
messages_list.append(formatted_response) # type: ignore
return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}}
return {
"messages": messages_list,
"stream": stream,
"session_state": session_state,
"context": {},
}

chat_target = _CallbackChatTarget(callback=callback_target) # type: ignore

Expand Down
Loading