Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/code/scoring/8_scorer_metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@
")\n",
"\n",
"if metrics:\n",
" objective_metrics = cast(ObjectiveScorerMetrics, metrics)\n",
" objective_metrics = cast(\"ObjectiveScorerMetrics\", metrics)\n",
" print(f\" Accuracy: {objective_metrics.accuracy}\")\n",
"else:\n",
" raise RuntimeError(\"Evaluation failed, no metrics returned\")"
Expand Down Expand Up @@ -604,7 +604,7 @@
")\n",
"\n",
"if metrics:\n",
" harm_metrics = cast(HarmScorerMetrics, metrics)\n",
" harm_metrics = cast(\"HarmScorerMetrics\", metrics)\n",
" print(f'Metrics for harm category \"{harm_metrics.harm_category}\" created')\n",
"else:\n",
" raise RuntimeError(\"Evaluation failed, no metrics returned\")"
Expand Down
4 changes: 2 additions & 2 deletions doc/code/scoring/8_scorer_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
)

if metrics:
objective_metrics = cast(ObjectiveScorerMetrics, metrics)
objective_metrics = cast("ObjectiveScorerMetrics", metrics)
print(f" Accuracy: {objective_metrics.accuracy}")
else:
raise RuntimeError("Evaluation failed, no metrics returned")
Expand Down Expand Up @@ -318,7 +318,7 @@
)

if metrics:
harm_metrics = cast(HarmScorerMetrics, metrics)
harm_metrics = cast("HarmScorerMetrics", metrics)
print(f'Metrics for harm category "{harm_metrics.harm_category}" created')
else:
raise RuntimeError("Evaluation failed, no metrics returned")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ select = [
"PIE", # https://docs.astral.sh/ruff/rules/#flake8-pie-pie
"RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret
"SIM", # https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
"TCH", # https://docs.astral.sh/ruff/rules/#flake8-type-checking-tch
"UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
"W", # https://docs.astral.sh/ruff/rules/#pycodestyle-w
]
Expand Down Expand Up @@ -293,7 +294,7 @@ notice-rgx = "Copyright \\(c\\) Microsoft Corporation\\.\\s*\\n.*Licensed under
# Ignore D and DOC rules everywhere except for the pyrit/ directory
"!pyrit/**.py" = ["D", "DOC"]
# Ignore copyright check only in doc/ directory
"doc/**" = ["CPY001"]
"doc/**" = ["CPY001", "TCH"]
# Temporary ignores for pyrit/ subdirectories until issue #1176
# https://github.com/Azure/PyRIT/issues/1176 is fully resolved
# TODO: Remove these ignores once the issues are fixed
Expand Down
9 changes: 5 additions & 4 deletions pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import logging
import time
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Union, cast
from urllib.parse import urlparse

Expand All @@ -24,6 +23,8 @@
)

if TYPE_CHECKING:
from collections.abc import Callable

import azure.cognitiveservices.speech as speechsdk

from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC
Expand Down Expand Up @@ -136,7 +137,7 @@ def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = "") -> str:
try:
credential = AzureCliCredential(tenant_id=tenant_id)
token = credential.get_token(scope)
return cast(str, token.token)
return cast("str", token.token)
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with tenant ID '{tenant_id}': {e}")
raise
Expand All @@ -158,7 +159,7 @@ def get_access_token_from_azure_msi(*, client_id: str, scope: str) -> str:
try:
credential = ManagedIdentityCredential(client_id=client_id)
token = credential.get_token(scope)
return cast(str, token.token)
return cast("str", token.token)
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}")
raise
Expand All @@ -179,7 +180,7 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str) -> st
try:
app = msal.PublicClientApplication(client_id)
result = app.acquire_token_interactive(scopes=[scope])
return cast(str, result["access_token"])
return cast("str", result["access_token"])
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}")
raise
Expand Down
10 changes: 6 additions & 4 deletions pyrit/backend/mappers/attack_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@

import mimetypes
import uuid
from collections.abc import Sequence
from datetime import datetime, timezone
from typing import Optional, cast
from typing import TYPE_CHECKING, Optional, cast

from pyrit.backend.models.attacks import (
AddMessageRequest,
Expand All @@ -29,6 +28,9 @@
from pyrit.models import MessagePiece as PyritMessagePiece
from pyrit.models import Score as PyritScore

if TYPE_CHECKING:
from collections.abc import Sequence

# ============================================================================
# Domain → DTO (for API responses)
# ============================================================================
Expand Down Expand Up @@ -194,9 +196,9 @@ def request_piece_to_pyrit_message_piece(
return PyritMessagePiece(
role=role,
original_value=piece.original_value,
original_value_data_type=cast(PromptDataType, piece.data_type),
original_value_data_type=cast("PromptDataType", piece.data_type),
converted_value=piece.converted_value or piece.original_value,
converted_value_data_type=cast(PromptDataType, piece.data_type),
converted_value_data_type=cast("PromptDataType", piece.data_type),
conversation_id=conversation_id,
sequence=sequence,
prompt_metadata=metadata,
Expand Down
3 changes: 2 additions & 1 deletion pyrit/cli/frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import json
import logging
import sys
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

Expand All @@ -43,6 +42,8 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i


if TYPE_CHECKING:
from collections.abc import Callable, Sequence

from pyrit.models.scenario_result import ScenarioResult
from pyrit.registry import (
InitializerMetadata,
Expand Down
6 changes: 4 additions & 2 deletions pyrit/common/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable


def print_deprecation_message(
Expand Down
2 changes: 1 addition & 1 deletion pyrit/common/json_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def read_json(file: IO[Any]) -> list[dict[str, str]]:
Returns:
List[Dict[str, str]]: Parsed JSON content.
"""
return cast(list[dict[str, str]], json.load(file))
return cast("list[dict[str, str]]", json.load(file))


def write_json(file: IO[Any], examples: list[dict[str, str]]) -> None:
Expand Down
10 changes: 6 additions & 4 deletions pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str
"""
self._validate_file_type(file_type)
with cache_file.open("r", encoding="utf-8") as file:
return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file))
return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file))

def _write_cache(self, *, cache_file: Path, examples: list[dict[str, str]], file_type: str) -> None:
"""
Expand Down Expand Up @@ -130,9 +130,11 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st
if response.status_code == 200:
if file_type in FILE_TYPE_HANDLERS:
if file_type == "json":
return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)))
return cast(
"list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))
)
return cast(
list[dict[str, str]],
"list[dict[str, str]]",
FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))),
)
valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
Expand All @@ -155,7 +157,7 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str
"""
with open(source, encoding="utf-8") as file:
if file_type in FILE_TYPE_HANDLERS:
return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file))
return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file))
valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")

Expand Down
10 changes: 6 additions & 4 deletions pyrit/executor/attack/core/attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from __future__ import annotations

import dataclasses
import logging
import logging # noqa: TC003
import time
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, Generic, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload

from pyrit.common.logger import logger
from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT
from pyrit.executor.core import (
Strategy,
Expand All @@ -28,7 +27,10 @@
ConversationReference,
Message,
)
from pyrit.prompt_target import PromptTarget

if TYPE_CHECKING:
from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.prompt_target import PromptTarget

AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext[Any]")
AttackStrategyResultT = TypeVar("AttackStrategyResultT", bound="AttackResult")
Expand Down
6 changes: 4 additions & 2 deletions pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import textwrap
from dataclasses import dataclass, field
from string import Formatter
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
from pyrit.exceptions import ComponentRole, execution_context
Expand All @@ -28,7 +28,9 @@
)
from pyrit.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target import PromptTarget
from pyrit.score import TrueFalseScorer

if TYPE_CHECKING:
from pyrit.score import TrueFalseScorer

logger = logging.getLogger(__name__)

Expand Down
8 changes: 5 additions & 3 deletions pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import json
import logging
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH
Expand Down Expand Up @@ -52,6 +51,9 @@
)
from pyrit.score.score_utils import normalize_score_to_float

if TYPE_CHECKING:
from collections.abc import Callable

logger = logging.getLogger(__name__)


Expand All @@ -78,7 +80,7 @@ def backtrack_count(self) -> int:
Returns:
int: The number of backtracks.
"""
return cast(int, self.metadata.get("backtrack_count", 0))
return cast("int", self.metadata.get("backtrack_count", 0))

@backtrack_count.setter
def backtrack_count(self, value: int) -> None:
Expand Down
16 changes: 9 additions & 7 deletions pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from __future__ import annotations

import logging
import logging # noqa: TC003
import uuid
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar

from pyrit.common.logger import logger
from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT
Expand All @@ -16,11 +16,13 @@
AttackStrategy,
AttackStrategyResultT,
)
from pyrit.models import (
Message,
Score,
)
from pyrit.prompt_target import PromptTarget

if TYPE_CHECKING:
from pyrit.models import (
Message,
Score,
)
from pyrit.prompt_target import PromptTarget

MultiTurnAttackStrategyContextT = TypeVar("MultiTurnAttackStrategyContextT", bound="MultiTurnAttackContext[Any]")

Expand Down
9 changes: 6 additions & 3 deletions pyrit/executor/attack/multi_turn/red_teaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import enum
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
from pyrit.common.path import EXECUTOR_RED_TEAM_PATH
Expand Down Expand Up @@ -38,7 +37,11 @@
SeedPrompt,
)
from pyrit.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target.common.prompt_target import PromptTarget

if TYPE_CHECKING:
from collections.abc import Callable

from pyrit.prompt_target.common.prompt_target import PromptTarget

logger = logging.getLogger(__name__)

Expand Down
11 changes: 7 additions & 4 deletions pyrit/executor/attack/multi_turn/simulated_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

from pyrit.executor.attack.core.attack_config import (
AttackAdversarialConfig,
Expand All @@ -23,8 +22,12 @@
from pyrit.memory import CentralMemory
from pyrit.message_normalizer import ConversationContextNormalizer
from pyrit.models import Message, SeedPrompt, SeedSimulatedConversation
from pyrit.prompt_target import PromptChatTarget
from pyrit.score import TrueFalseScorer

if TYPE_CHECKING:
from pathlib import Path

from pyrit.prompt_target import PromptChatTarget
from pyrit.score import TrueFalseScorer

logger = logging.getLogger(__name__)

Expand Down
Loading