diff --git a/doc/code/scoring/8_scorer_metrics.ipynb b/doc/code/scoring/8_scorer_metrics.ipynb index a82bf84ec1..5dd26b483a 100644 --- a/doc/code/scoring/8_scorer_metrics.ipynb +++ b/doc/code/scoring/8_scorer_metrics.ipynb @@ -552,7 +552,7 @@ ")\n", "\n", "if metrics:\n", - " objective_metrics = cast(ObjectiveScorerMetrics, metrics)\n", + " objective_metrics = cast(\"ObjectiveScorerMetrics\", metrics)\n", " print(f\" Accuracy: {objective_metrics.accuracy}\")\n", "else:\n", " raise RuntimeError(\"Evaluation failed, no metrics returned\")" @@ -604,7 +604,7 @@ ")\n", "\n", "if metrics:\n", - " harm_metrics = cast(HarmScorerMetrics, metrics)\n", + " harm_metrics = cast(\"HarmScorerMetrics\", metrics)\n", " print(f'Metrics for harm category \"{harm_metrics.harm_category}\" created')\n", "else:\n", " raise RuntimeError(\"Evaluation failed, no metrics returned\")" diff --git a/doc/code/scoring/8_scorer_metrics.py b/doc/code/scoring/8_scorer_metrics.py index 2a1443da32..3430556b13 100644 --- a/doc/code/scoring/8_scorer_metrics.py +++ b/doc/code/scoring/8_scorer_metrics.py @@ -286,7 +286,7 @@ ) if metrics: - objective_metrics = cast(ObjectiveScorerMetrics, metrics) + objective_metrics = cast("ObjectiveScorerMetrics", metrics) print(f" Accuracy: {objective_metrics.accuracy}") else: raise RuntimeError("Evaluation failed, no metrics returned") @@ -318,7 +318,7 @@ ) if metrics: - harm_metrics = cast(HarmScorerMetrics, metrics) + harm_metrics = cast("HarmScorerMetrics", metrics) print(f'Metrics for harm category "{harm_metrics.harm_category}" created') else: raise RuntimeError("Evaluation failed, no metrics returned") diff --git a/pyproject.toml b/pyproject.toml index 9e9d492040..dfe62409a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -261,6 +261,7 @@ select = [ "PIE", # https://docs.astral.sh/ruff/rules/#flake8-pie-pie "RET", # https://docs.astral.sh/ruff/rules/#flake8-return-ret "SIM", # https://docs.astral.sh/ruff/rules/#flake8-simplify-sim + "TCH", # https://docs.astral.sh/ruff/rules/#flake8-type-checking-tch "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up "W", # https://docs.astral.sh/ruff/rules/#pycodestyle-w ] @@ -293,7 +294,7 @@ notice-rgx = "Copyright \\(c\\) Microsoft Corporation\\.\\s*\\n.*Licensed under # Ignore D and DOC rules everywhere except for the pyrit/ directory "!pyrit/**.py" = ["D", "DOC"] # Ignore copyright check only in doc/ directory -"doc/**" = ["CPY001"] +"doc/**" = ["CPY001", "TCH"] # Temporary ignores for pyrit/ subdirectories until issue #1176 # https://github.com/Azure/PyRIT/issues/1176 is fully resolved # TODO: Remove these ignores once the issues are fixed diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index a1f526545d..5192379382 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -5,7 +5,6 @@ import logging import time -from collections.abc import Callable from typing import TYPE_CHECKING, Any, Union, cast from urllib.parse import urlparse @@ -24,6 +23,8 @@ ) if TYPE_CHECKING: + from collections.abc import Callable + import azure.cognitiveservices.speech as speechsdk from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC @@ -136,7 +137,7 @@ def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = "") -> str: try: credential = AzureCliCredential(tenant_id=tenant_id) token = credential.get_token(scope) - return cast(str, token.token) + return cast("str", token.token) except Exception as e: logger.error(f"Failed to obtain token for '{scope}' with tenant ID '{tenant_id}': {e}") raise @@ -158,7 +159,7 @@ def get_access_token_from_azure_msi(*, client_id: str, scope: str) -> str: try: credential = ManagedIdentityCredential(client_id=client_id) token = credential.get_token(scope) - return cast(str, token.token) + return cast("str", token.token) except Exception as e: logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}") raise @@ -179,7 +180,7 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str) -> st try: app = msal.PublicClientApplication(client_id) result = app.acquire_token_interactive(scopes=[scope]) - return cast(str, result["access_token"]) + return cast("str", result["access_token"]) except Exception as e: logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}") raise diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index e910c2b69c..c9c8dc7af2 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -12,9 +12,8 @@ import mimetypes import uuid -from collections.abc import Sequence from datetime import datetime, timezone -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from pyrit.backend.models.attacks import ( AddMessageRequest, @@ -29,6 +28,9 @@ from pyrit.models import MessagePiece as PyritMessagePiece from pyrit.models import Score as PyritScore +if TYPE_CHECKING: + from collections.abc import Sequence + # ============================================================================ # Domain → DTO (for API responses) # ============================================================================ @@ -194,9 +196,9 @@ def request_piece_to_pyrit_message_piece( return PyritMessagePiece( role=role, original_value=piece.original_value, - original_value_data_type=cast(PromptDataType, piece.data_type), + original_value_data_type=cast("PromptDataType", piece.data_type), converted_value=piece.converted_value or piece.original_value, - converted_value_data_type=cast(PromptDataType, piece.data_type), + converted_value_data_type=cast("PromptDataType", piece.data_type), conversation_id=conversation_id, sequence=sequence, prompt_metadata=metadata, diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 809c3450ca..5c49525ec5 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -18,7 +18,6 @@ import json import logging import sys -from collections.abc import Callable, Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -43,6 +42,8 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from pyrit.models.scenario_result import ScenarioResult from pyrit.registry import ( InitializerMetadata, diff --git a/pyrit/common/deprecation.py b/pyrit/common/deprecation.py index b730f934fa..89da08bdb9 100644 --- a/pyrit/common/deprecation.py +++ b/pyrit/common/deprecation.py @@ -4,8 +4,10 @@ from __future__ import annotations import warnings -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable def print_deprecation_message( diff --git a/pyrit/common/json_helper.py b/pyrit/common/json_helper.py index 668b4335fb..17b49335de 100644 --- a/pyrit/common/json_helper.py +++ b/pyrit/common/json_helper.py @@ -12,7 +12,7 @@ def read_json(file: IO[Any]) -> list[dict[str, str]]: Returns: List[Dict[str, str]]: Parsed JSON content. """ - return cast(list[dict[str, str]], json.load(file)) + return cast("list[dict[str, str]]", json.load(file)) def write_json(file: IO[Any], examples: list[dict[str, str]]) -> None: diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index ff7babb9bd..5cd9212846 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -92,7 +92,7 @@ def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str """ self._validate_file_type(file_type) with cache_file.open("r", encoding="utf-8") as file: - return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) + return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file)) def _write_cache(self, *, cache_file: Path, examples: list[dict[str, str]], file_type: str) -> None: """ @@ -130,9 +130,11 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st if response.status_code == 200: if file_type in FILE_TYPE_HANDLERS: if file_type == "json": - return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))) + return cast( + "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) + ) return cast( - list[dict[str, str]], + "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), ) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) @@ -155,7 +157,7 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str """ with open(source, encoding="utf-8") as file: if file_type in FILE_TYPE_HANDLERS: - return cast(list[dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](file)) + return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file)) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index a8a6ecec67..dac86d10aa 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -4,14 +4,13 @@ from __future__ import annotations import dataclasses -import logging +import logging # noqa: TC003 import time from abc import ABC from dataclasses import dataclass, field -from typing import Any, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger -from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.core import ( Strategy, @@ -28,7 +27,10 @@ ConversationReference, Message, ) -from pyrit.prompt_target import PromptTarget + +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.prompt_target import PromptTarget AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext[Any]") AttackStrategyResultT = TypeVar("AttackStrategyResultT", bound="AttackResult") diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 3b8592018a..392aa0333e 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -5,7 +5,7 @@ import textwrap from dataclasses import dataclass, field from string import Formatter -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.exceptions import ComponentRole, execution_context @@ -28,7 +28,9 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget -from pyrit.score import TrueFalseScorer + +if TYPE_CHECKING: + from pyrit.score import TrueFalseScorer logger = logging.getLogger(__name__) diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index d9af83e540..cacf51c6e4 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -3,10 +3,9 @@ import json import logging -from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -52,6 +51,9 @@ ) from pyrit.score.score_utils import normalize_score_to_float +if TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) @@ -78,7 +80,7 @@ def backtrack_count(self) -> int: Returns: int: The number of backtracks. """ - return cast(int, self.metadata.get("backtrack_count", 0)) + return cast("int", self.metadata.get("backtrack_count", 0)) @backtrack_count.setter def backtrack_count(self, value: int) -> None: diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 2bcdeb391e..5620c90032 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -3,11 +3,11 @@ from __future__ import annotations -import logging +import logging # noqa: TC003 import uuid from abc import ABC from dataclasses import dataclass, field -from typing import Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -16,11 +16,13 @@ AttackStrategy, AttackStrategyResultT, ) -from pyrit.models import ( - Message, - Score, -) -from pyrit.prompt_target import PromptTarget + +if TYPE_CHECKING: + from pyrit.models import ( + Message, + Score, + ) + from pyrit.prompt_target import PromptTarget MultiTurnAttackStrategyContextT = TypeVar("MultiTurnAttackStrategyContextT", bound="MultiTurnAttackContext[Any]") diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 26473e52ff..3c27842366 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -5,9 +5,8 @@ import enum import logging -from collections.abc import Callable from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH @@ -38,7 +37,11 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target.common.prompt_target import PromptTarget + +if TYPE_CHECKING: + from collections.abc import Callable + + from pyrit.prompt_target.common.prompt_target import PromptTarget logger = logging.getLogger(__name__) diff --git a/pyrit/executor/attack/multi_turn/simulated_conversation.py b/pyrit/executor/attack/multi_turn/simulated_conversation.py index cb77e55984..40c3bb515a 100644 --- a/pyrit/executor/attack/multi_turn/simulated_conversation.py +++ b/pyrit/executor/attack/multi_turn/simulated_conversation.py @@ -11,8 +11,7 @@ from __future__ import annotations import logging -from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from pyrit.executor.attack.core.attack_config import ( AttackAdversarialConfig, @@ -23,8 +22,12 @@ from pyrit.memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer from pyrit.models import Message, SeedPrompt, SeedSimulatedConversation -from pyrit.prompt_target import PromptChatTarget -from pyrit.score import TrueFalseScorer + +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.prompt_target import PromptChatTarget + from pyrit.score import TrueFalseScorer logger = logging.getLogger(__name__) diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 7414bcbdc6..1c9c0bc6ee 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -166,7 +166,7 @@ class TAPAttackResult(AttackResult): @property def tree_visualization(self) -> Optional[Tree]: """Get the tree visualization from metadata.""" - return cast(Optional[Tree], self.metadata.get("tree_visualization", None)) + return cast("Optional[Tree]", self.metadata.get("tree_visualization", None)) @tree_visualization.setter def tree_visualization(self, value: Tree) -> None: @@ -176,7 +176,7 @@ def tree_visualization(self, value: Tree) -> None: @property def nodes_explored(self) -> int: """Get the total number of nodes explored during the attack.""" - return cast(int, self.metadata.get("nodes_explored", 0)) + return cast("int", self.metadata.get("nodes_explored", 0)) @nodes_explored.setter def nodes_explored(self, value: int) -> None: @@ -186,7 +186,7 @@ def nodes_explored(self, value: int) -> None: @property def nodes_pruned(self) -> int: """Get the number of nodes pruned during the attack.""" - return cast(int, self.metadata.get("nodes_pruned", 0)) + return cast("int", self.metadata.get("nodes_pruned", 0)) @nodes_pruned.setter def nodes_pruned(self, value: int) -> None: @@ -196,7 +196,7 @@ def nodes_pruned(self, value: int) -> None: @property def max_depth_reached(self) -> int: """Get the maximum depth reached in the attack tree.""" - return cast(int, self.metadata.get("max_depth_reached", 0)) + return cast("int", self.metadata.get("max_depth_reached", 0)) @max_depth_reached.setter def max_depth_reached(self, value: int) -> None: @@ -206,7 +206,7 @@ def max_depth_reached(self, value: int) -> None: @property def auxiliary_scores_summary(self) -> dict[str, float]: """Get a summary of auxiliary scores from the best node.""" - return cast(dict[str, float], self.metadata.get("auxiliary_scores_summary", {})) + return cast("dict[str, float]", self.metadata.get("auxiliary_scores_summary", {})) @auxiliary_scores_summary.setter def auxiliary_scores_summary(self, value: dict[str, float]) -> None: @@ -216,7 +216,7 @@ def auxiliary_scores_summary(self, value: dict[str, float]) -> None: @property def best_adversarial_conversation_id(self) -> Optional[str]: """Get the adversarial conversation ID for the best-scoring branch.""" - return cast(Optional[str], self.metadata.get("best_adversarial_conversation_id", None)) + return cast("Optional[str]", self.metadata.get("best_adversarial_conversation_id", None)) @best_adversarial_conversation_id.setter def best_adversarial_conversation_id(self, value: Optional[str]) -> None: @@ -493,7 +493,7 @@ async def _generate_adversarial_prompt_async(self, objective: str) -> str: prompt = await self._generate_red_teaming_prompt_async(objective=objective) self.last_prompt_sent = prompt logger.debug(f"Node {self.node_id}: Generated adversarial prompt") - return cast(str, prompt) + return cast("str", prompt) async def _send_prompt_to_target_async(self, prompt: str) -> Message: """ @@ -1136,7 +1136,7 @@ def _parse_red_teaming_response(self, red_teaming_response: str) -> str: raise InvalidJsonException(message="The response from the red teaming chat is not in JSON format.") try: - return cast(str, red_teaming_response_dict["prompt"]) + return cast("str", red_teaming_response_dict["prompt"]) except KeyError: logger.error(f"The response from the red teaming chat does not contain a prompt: {red_teaming_response}") raise InvalidJsonException(message="The response from the red teaming chat does not contain a prompt.") @@ -1838,7 +1838,7 @@ def _create_attack_node( generate adversarial prompts and evaluate responses. """ node = _TreeOfAttacksNode( - objective_target=cast(PromptChatTarget, self._objective_target), + objective_target=cast("PromptChatTarget", self._objective_target), adversarial_chat=self._adversarial_chat, adversarial_chat_seed_prompt=self._adversarial_chat_seed_prompt, adversarial_chat_system_seed_prompt=self._adversarial_chat_system_seed_prompt, diff --git a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py index 4aae80cf90..500dacb898 100644 --- a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py +++ b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py @@ -33,7 +33,7 @@ def fetch_many_shot_jailbreaking_dataset() -> list[dict[str, str]]: source = "https://raw.githubusercontent.com/KutalVolkan/many-shot-jailbreaking-dataset/5eac855/examples.json" response = requests.get(source) response.raise_for_status() - return cast(list[dict[str, str]], response.json()) + return cast("list[dict[str, str]]", response.json()) class ManyShotJailbreakAttack(PromptSendingAttack): diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index e3c0560bf5..508f83db1c 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -3,17 +3,19 @@ from __future__ import annotations -import logging +import logging # noqa: TC003 import uuid from abc import ABC from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy from pyrit.models import AttackResult -from pyrit.prompt_target import PromptTarget + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget @dataclass diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index f728d27dba..39f8ce68f6 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -369,7 +369,7 @@ def get_last_context(self) -> Optional[FairnessBiasBenchmarkContext]: Optional[FairnessBiasBenchmarkContext]: The context from the most recent execution, or None if no execution has occurred """ - return cast(Optional[FairnessBiasBenchmarkContext], getattr(self, "_last_context", None)) + return cast("Optional[FairnessBiasBenchmarkContext]", getattr(self, "_last_context", None)) async def _teardown_async(self, *, context: FairnessBiasBenchmarkContext) -> None: """ diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 383465aacf..29fd4a8322 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -8,18 +8,20 @@ import logging import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, MutableMapping from contextlib import asynccontextmanager from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from pyrit.common import default_values from pyrit.common.logger import logger from pyrit.exceptions import clear_execution_context, get_execution_context from pyrit.models import StrategyResultT +if TYPE_CHECKING: + from collections.abc import AsyncIterator, MutableMapping + StrategyContextT = TypeVar("StrategyContextT", bound="StrategyContext") diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index adf3c97628..208c4040d7 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -7,7 +7,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional, Union, overload +from typing import TYPE_CHECKING, Any, Optional, Union, overload import yaml @@ -24,7 +24,9 @@ Message, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target import PromptChatTarget + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptChatTarget logger = logging.getLogger(__name__) diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index 39eb0b6b31..29068611ba 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -3,7 +3,7 @@ from __future__ import annotations -import logging +import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass from typing import Optional, TypeVar diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index b86473188e..ebd67c2bda 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -8,7 +8,7 @@ import textwrap import uuid from dataclasses import dataclass, field -from typing import Any, Optional, Union, overload +from typing import TYPE_CHECKING, Any, Optional, Union, overload import numpy as np from colorama import Fore, Style @@ -23,7 +23,6 @@ PromptGeneratorStrategyContext, PromptGeneratorStrategyResult, ) -from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory from pyrit.models import ( @@ -33,9 +32,12 @@ SeedPrompt, ) from pyrit.prompt_normalizer import NormalizerRequest, PromptNormalizer -from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.score import FloatScaleThresholdScorer, Scorer, SelfAskScaleScorer +if TYPE_CHECKING: + from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter + from pyrit.prompt_target import PromptChatTarget, PromptTarget + logger = logging.getLogger(__name__) diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index 695d3b3aae..80dc3a298d 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -3,7 +3,7 @@ from __future__ import annotations -import logging +import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass from typing import Optional, TypeVar diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 6fe26f445b..6702f22407 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -6,9 +6,8 @@ from collections.abc import MutableSequence, Sequence from contextlib import closing from datetime import datetime, timedelta, timezone -from typing import Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from azure.core.credentials import AccessToken from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError @@ -31,6 +30,9 @@ MessagePiece, ) +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + logger = logging.getLogger(__name__) Model = TypeVar("Model") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c3cea3b483..67e6dcfb6d 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -11,12 +11,11 @@ from contextlib import closing from datetime import datetime from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from sqlalchemy import MetaData, and_, or_ from sqlalchemy.engine.base import Engine from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.sql.elements import ColumnElement from pyrit.common.path import DB_DATA_PATH from pyrit.memory.memory_embedding import ( @@ -50,6 +49,9 @@ sort_message_pieces, ) +if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + logger = logging.getLogger(__name__) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 3a11938edf..0ebfa37946 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,7 +4,7 @@ import base64 import json import os -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from pyrit.common import convert_local_image_to_data_url from pyrit.message_normalizer.message_normalizer import ( @@ -14,9 +14,11 @@ apply_system_message_behavior, ) from pyrit.models import ChatMessage, DataTypeSerializer, Message -from pyrit.models.literals import ChatMessageRole from pyrit.models.message_piece import MessagePiece +if TYPE_CHECKING: + from pyrit.models.literals import ChatMessageRole + # Supported audio formats for OpenAI input_audio # https://platform.openai.com/docs/guides/audio SUPPORTED_AUDIO_FORMATS = {".wav": "wav", ".mp3": "mp3"} diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index db31126efd..bd7e81b02f 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -122,10 +122,10 @@ def _load_tokenizer(model_name: str, token: Optional[str]) -> "PreTrainedTokeniz Returns: The loaded tokenizer. """ - from transformers import AutoTokenizer, PreTrainedTokenizerBase + from transformers import AutoTokenizer return cast( - PreTrainedTokenizerBase, + "PreTrainedTokenizerBase", AutoTokenizer.from_pretrained(model_name, token=token or None), # type: ignore[no-untyped-call, unused-ignore] ) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index e40d0d228b..cd9efff5ce 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -5,14 +5,16 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar -from pyrit.identifiers.component_identifier import ComponentIdentifier -from pyrit.models.conversation_reference import ConversationReference, ConversationType -from pyrit.models.message_piece import MessagePiece -from pyrit.models.score import Score from pyrit.models.strategy_result import StrategyResult +if TYPE_CHECKING: + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.message_piece import MessagePiece + from pyrit.models.score import Score + AttackResultT = TypeVar("AttackResultT", bound="AttackResult") diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index ca2473c278..c2004160fb 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -17,11 +17,11 @@ import aiofiles from pyrit.common.path import DB_DATA_PATH -from pyrit.models.literals import PromptDataType from pyrit.models.storage_io import DiskStorageIO, StorageIO if TYPE_CHECKING: from pyrit.memory import MemoryInterface + from pyrit.models.literals import PromptDataType # Define allowed categories for validation AllowedCategories = Literal["seed-prompt-entries", "prompt-memory-entries"] diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 91f604f2bc..ba6c09c937 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -6,14 +6,17 @@ import copy import uuid import warnings -from collections.abc import MutableSequence, Sequence from datetime import datetime -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from pyrit.common.utils import combine_dict -from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError from pyrit.models.message_piece import MessagePiece +if TYPE_CHECKING: + from collections.abc import MutableSequence, Sequence + + from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError + class Message: """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 9f1a0ea30e..62a6e4890d 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,12 +5,14 @@ import uuid from datetime import datetime -from typing import Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args from uuid import uuid4 from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError -from pyrit.models.score import Score + +if TYPE_CHECKING: + from pyrit.models.score import Score Originator = Literal["attack", "converter", "undefined", "scorer"] diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 25129a4091..c506a49a12 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -13,16 +13,19 @@ import logging import re import uuid -from collections.abc import Iterator, Sequence from dataclasses import dataclass, field from datetime import datetime -from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from jinja2 import BaseLoader, Environment, StrictUndefined, Template, Undefined from pyrit.common.yaml_loadable import YamlLoadable -from pyrit.models.literals import PromptDataType + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + from pathlib import Path + + from pyrit.models.literals import PromptDataType logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index 3a71c75625..b994f5108e 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -9,13 +9,16 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union -from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective +if TYPE_CHECKING: + from collections.abc import Sequence + + from pyrit.models.seeds.seed import Seed + class SeedAttackGroup(SeedGroup): """ diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index 0d7dd91693..7f5e16d28e 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -11,12 +11,15 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union -from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup +if TYPE_CHECKING: + from collections.abc import Sequence + + from pyrit.models.seeds.seed import Seed + class SeedAttackTechniqueGroup(SeedGroup): """ diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 6ae4ff91ce..413f216a97 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -12,22 +12,25 @@ import uuid import warnings from collections import defaultdict -from collections.abc import Sequence from datetime import datetime -from typing import Any, Optional, Union - -from pydantic.types import PositiveInt +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common import utils from pyrit.common.yaml_loadable import YamlLoadable -from pyrit.models.literals import PromptDataType, SeedType -from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_attack_group import SeedAttackGroup from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt from pyrit.models.seeds.seed_simulated_conversation import SeedSimulatedConversation +if TYPE_CHECKING: + from collections.abc import Sequence + + from pydantic.types import PositiveInt + + from pyrit.models.literals import PromptDataType, SeedType + from pyrit.models.seeds.seed import Seed + logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index cff0ca81ee..0c25e41ae0 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -14,8 +14,7 @@ import uuid import warnings from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.yaml_loadable import YamlLoadable from pyrit.models.message import Message @@ -25,6 +24,9 @@ from pyrit.models.seeds.seed_prompt import SeedPrompt from pyrit.models.seeds.seed_simulated_conversation import SeedSimulatedConversation +if TYPE_CHECKING: + from collections.abc import Sequence + logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index dde787ea05..85010e404e 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -9,12 +9,14 @@ import logging from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from pyrit.common.path import PATHS_DICT from pyrit.models.seeds.seed import Seed +if TYPE_CHECKING: + from pathlib import Path + logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 63859531af..b507cf3173 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -9,21 +9,22 @@ import logging import os -import uuid -from collections.abc import Sequence from dataclasses import dataclass, field -from pathlib import Path from typing import TYPE_CHECKING, Optional, Union from tinytag import TinyTag from pyrit.common.path import PATHS_DICT from pyrit.models import DataTypeSerializer -from pyrit.models.literals import ChatMessageRole, PromptDataType from pyrit.models.seeds.seed import Seed if TYPE_CHECKING: + import uuid + from collections.abc import Sequence + from pathlib import Path + from pyrit.models import Message + from pyrit.models.literals import ChatMessageRole, PromptDataType logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index 8488e1eaf2..8cbf4d8671 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -98,7 +98,7 @@ def _load_font(self) -> FreeTypeFont: font = ImageFont.truetype(self._font_name, self._font_size) except OSError: logger.warning(f"Cannot open font resource: {self._font_name}. Using default font.") - font = cast(FreeTypeFont, ImageFont.load_default()) + font = cast("FreeTypeFont", ImageFont.load_default()) return font def _add_text_to_image(self, text: str) -> Image.Image: diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index e1cf11f669..91fd265e57 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -100,7 +100,7 @@ def _load_font(self) -> FreeTypeFont: font = ImageFont.truetype(self._font_name, self._font_size) except OSError: logger.warning(f"Cannot open font resource: {self._font_name}. Using default font.") - font = cast(FreeTypeFont, ImageFont.load_default()) + font = cast("FreeTypeFont", ImageFont.load_default()) return font def _add_text_to_image(self, image: Image.Image) -> Image.Image: diff --git a/pyrit/prompt_converter/binary_converter.py b/pyrit/prompt_converter/binary_converter.py index 700cfd29b5..12b139f22a 100644 --- a/pyrit/prompt_converter/binary_converter.py +++ b/pyrit/prompt_converter/binary_converter.py @@ -4,12 +4,14 @@ from __future__ import annotations from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ComponentIdentifier -from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy from pyrit.prompt_converter.word_level_converter import WordLevelConverter +if TYPE_CHECKING: + from pyrit.identifiers import ComponentIdentifier + from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy + class BinaryConverter(WordLevelConverter): """ diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 48fea1d2ba..05c67f7ca4 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -7,17 +7,20 @@ import hashlib from dataclasses import dataclass from io import BytesIO -from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from docx import Document from pyrit.common.logger import logger -from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType, SeedPrompt, data_serializer_factory -from pyrit.models.data_type_serializer import DataTypeSerializer from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.identifiers import ComponentIdentifier + from pyrit.models.data_type_serializer import DataTypeSerializer + @dataclass class _WordDocInjectionConfig: diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 6eff804f66..15f1cc80b7 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -331,7 +331,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Decode the assistant's response from the generated token IDs assistant_response = cast( - str, + "str", self.tokenizer.decode(generated_tokens, skip_special_tokens=self.skip_special_tokens), ).strip() @@ -372,7 +372,7 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: # Apply the chat template to format and tokenize the messages return cast( - BatchEncoding, + "BatchEncoding", self.tokenizer.apply_chat_template( messages, tokenize=True, diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index d61923bc32..c338a80710 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -720,7 +720,7 @@ def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any continue if isinstance(section, dict) and section.get("type") == "function_call": # Do NOT skip function_call even if status == "completed" — we still need to emit the output. - return cast(dict[str, Any], section) + return cast("dict[str, Any]", section) return None async def _execute_call_section(self, tool_call_section: dict[str, Any]) -> dict[str, Any]: diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 3e915a37d5..276bbfc2c7 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -93,7 +93,9 @@ def __init__( """ super().__init__(**kwargs) - self._n_seconds: VideoSeconds = cast(VideoSeconds, str(n_seconds)) if isinstance(n_seconds, int) else n_seconds + self._n_seconds: VideoSeconds = ( + cast("VideoSeconds", str(n_seconds)) if isinstance(n_seconds, int) else n_seconds + ) self._validate_duration() self._size: VideoSize = self._validate_resolution(resolution_dimensions=resolution_dimensions) diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index 55946abc89..558ef70157 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -16,12 +16,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Iterator -from typing import Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from pyrit.identifiers import ComponentIdentifier from pyrit.registry.base import RegistryProtocol +if TYPE_CHECKING: + from collections.abc import Iterator + T = TypeVar("T") # The type of instances stored MetadataT = TypeVar("MetadataT", bound=ComponentIdentifier) diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index 70d9256b73..2b78ced2f9 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -11,14 +11,15 @@ from __future__ import annotations import random -from collections.abc import Sequence from typing import TYPE_CHECKING, Optional from pyrit.memory import CentralMemory from pyrit.models import SeedAttackGroup, SeedGroup -from pyrit.models.seeds.seed import Seed if TYPE_CHECKING: + from collections.abc import Sequence + + from pyrit.models.seeds.seed import Seed from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy # Key used when seed_groups are provided directly (not from a named dataset) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index d4b39f8b97..a3e260fe3a 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -20,7 +20,6 @@ from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack -from pyrit.identifiers import ComponentIdentifier from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry from pyrit.models import AttackResult @@ -36,6 +35,7 @@ if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedAttackGroup logger = logging.getLogger(__name__) @@ -340,7 +340,7 @@ def _get_baseline_data(self) -> tuple[list["SeedAttackGroup"], "AttackScoringCon # Import here to avoid circular imports from pyrit.executor.attack.core.attack_config import AttackScoringConfig - attack_scoring_config = AttackScoringConfig(objective_scorer=cast(TrueFalseScorer, self._objective_scorer)) + attack_scoring_config = AttackScoringConfig(objective_scorer=cast("TrueFalseScorer", self._objective_scorer)) if not attack_scoring_config: raise ValueError("Attack scoring config is required to create baseline attack.") diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index 336d35b9f1..c18ed60c14 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + """ Base class for scenario attack strategies with group-based aggregation. @@ -11,9 +13,11 @@ It also provides ScenarioCompositeStrategy for representing composed attack strategies. """ -from collections.abc import Sequence from enum import Enum -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from collections.abc import Sequence # TypeVar for the enum subclass itself T = TypeVar("T", bound="ScenarioStrategy") @@ -47,7 +51,7 @@ class ScenarioStrategy(Enum): _tags: set[str] - def __new__(cls, value: str, tags: set[str] | None = None) -> "ScenarioStrategy": + def __new__(cls, value: str, tags: set[str] | None = None) -> ScenarioStrategy: """ Create a new ScenarioStrategy with value and tags. @@ -195,10 +199,10 @@ def normalize_strategies(cls: type[T], strategies: set[T]) -> set[T]: @classmethod def prepare_scenario_strategies( cls: type[T], - strategies: Sequence[T | "ScenarioCompositeStrategy"] | None = None, + strategies: Sequence[T | ScenarioCompositeStrategy] | None = None, *, default_aggregate: T | None = None, - ) -> list["ScenarioCompositeStrategy"]: + ) -> list[ScenarioCompositeStrategy]: """ Prepare and normalize scenario strategies for use in a scenario. @@ -395,7 +399,7 @@ def is_single_strategy(self) -> bool: @staticmethod def extract_single_strategy_values( - composites: Sequence["ScenarioCompositeStrategy"], *, strategy_type: type[T] + composites: Sequence[ScenarioCompositeStrategy], *, strategy_type: type[T] ) -> set[str]: """ Extract strategy values from single-strategy composites. @@ -474,8 +478,8 @@ def get_composite_name(strategies: Sequence[ScenarioStrategy]) -> str: @staticmethod def normalize_compositions( - compositions: list["ScenarioCompositeStrategy"], *, strategy_type: type[T] - ) -> list["ScenarioCompositeStrategy"]: + compositions: list[ScenarioCompositeStrategy], *, strategy_type: type[T] + ) -> list[ScenarioCompositeStrategy]: """ Normalize strategy compositions by expanding aggregates while preserving concrete compositions. diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index e7ecf40650..a05085de4c 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -12,7 +12,6 @@ AttackAdversarialConfig, AttackScoringConfig, ) -from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.executor.attack.multi_turn.red_teaming import RedTeamingAttack from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.models import SeedAttackGroup, SeedObjective @@ -33,6 +32,9 @@ TrueFalseScorer, ) +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_strategy import AttackStrategy + logger = logging.getLogger(__name__) diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index 5232f2ce05..2ced987199 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -4,7 +4,7 @@ import logging import os from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from pyrit.common import apply_defaults from pyrit.common.path import ( @@ -21,7 +21,6 @@ AttackAdversarialConfig, AttackScoringConfig, ) -from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.models import SeedAttackGroup, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget from pyrit.scenario.core.atomic_attack import AtomicAttack @@ -40,6 +39,9 @@ TrueFalseScorer, ) +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_strategy import AttackStrategy + logger = logging.getLogger(__name__) PERSUASION_DECEPTION_PATH = Path(EXECUTOR_RED_TEAM_PATH, "persuasion_deception").resolve() diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index bba9703e7f..bfd51592ae 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -13,7 +13,7 @@ import os from collections.abc import Sequence from inspect import signature -from typing import Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -29,7 +29,6 @@ AttackConverterConfig, AttackScoringConfig, ) -from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.models import SeedAttackGroup, SeedObjective from pyrit.prompt_converter import ( AnsiAttackConverter, @@ -78,6 +77,9 @@ TrueFalseScoreAggregator, ) +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_strategy import AttackStrategy + AttackStrategyT = TypeVar("AttackStrategyT", bound="AttackStrategy[Any, Any]") logger = logging.getLogger(__name__) diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index 3e55f00c5a..333a2e812a 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -92,7 +92,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non response_error=original_piece.response_error, originator=original_piece.originator, original_prompt_id=( - cast(UUID, original_piece.original_prompt_id) + cast("UUID", original_piece.original_prompt_id) if isinstance(original_piece.original_prompt_id, str) else original_piece.original_prompt_id ), diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index 5a63853028..ae9e0acc4b 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -3,14 +3,16 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ComponentIdentifier -from pyrit.models import MessagePiece, Score, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator +if TYPE_CHECKING: + from pyrit.identifiers import ComponentIdentifier + from pyrit.models import MessagePiece, Score, UnvalidatedScore + from pyrit.prompt_target import PromptChatTarget + class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): """ diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index 14ce664017..4cadb91c1a 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional +from typing import TYPE_CHECKING, Optional from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score @@ -10,10 +10,12 @@ FloatScaleScorerByCategory, ) from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer -from pyrit.score.score_aggregator_result import ScoreAggregatorResult from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.video_scorer import _BaseVideoScorer +if TYPE_CHECKING: + from pyrit.score.score_aggregator_result import ScoreAggregatorResult + class VideoFloatScaleScorer( FloatScaleScorer, diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index bd5a1b2294..0680d243d4 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -9,7 +9,6 @@ import logging import uuid from abc import abstractmethod -from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, @@ -35,16 +34,18 @@ ScoreType, UnvalidatedScore, ) -from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.prompt_target.batch_helper import batch_task_async -from pyrit.score.scorer_prompt_validator import ScorerPromptValidator if TYPE_CHECKING: + from collections.abc import Sequence + + from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.score.scorer_evaluation.metrics_type import RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_evaluator import ( ScorerEvalDatasetFiles, ) from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics + from pyrit.score.scorer_prompt_validator import ScorerPromptValidator logger = logging.getLogger(__name__) @@ -410,7 +411,7 @@ async def score_prompts_batch_async( results = await batch_task_async( task_func=self.score_async, task_arguments=["message", "objective"], - prompt_target=cast(PromptTarget, prompt_target), + prompt_target=cast("PromptTarget", prompt_target), batch_size=batch_size, items_to_batch=[messages, objectives], role_filter=role_filter, diff --git a/pyrit/score/scorer_evaluation/human_labeled_dataset.py b/pyrit/score/scorer_evaluation/human_labeled_dataset.py index d0eab8d1a3..34dd7dc001 100644 --- a/pyrit/score/scorer_evaluation/human_labeled_dataset.py +++ b/pyrit/score/scorer_evaluation/human_labeled_dataset.py @@ -310,7 +310,7 @@ def from_csv( MessagePiece( role="assistant", original_value=response_to_score, - original_value_data_type=cast(PromptDataType, data_type), + original_value_data_type=cast("PromptDataType", data_type), ) ], ) diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index 2ff8022d65..d5ecfe1480 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -7,15 +7,12 @@ import logging import time from dataclasses import dataclass -from pathlib import Path -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast import numpy as np from scipy.stats import ttest_1samp from pyrit.common.path import SCORER_EVALS_PATH -from pyrit.models import Message -from pyrit.score import Scorer from pyrit.score.scorer_evaluation.human_labeled_dataset import ( HarmHumanLabeledEntry, HumanLabeledDataset, @@ -38,6 +35,12 @@ ) from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.models import Message + from pyrit.score import Scorer + logger = logging.getLogger(__name__) # Standard column names for evaluation datasets @@ -523,7 +526,7 @@ def _validate_and_extract_data( human_scores_list: list[list[float]] = [] for entry in labeled_dataset.entries: - harm_entry = cast(HarmHumanLabeledEntry, entry) + harm_entry = cast("HarmHumanLabeledEntry", entry) for message in harm_entry.conversation: self.scorer._memory.add_message_to_memory(request=message) assistant_responses.append(message) @@ -553,7 +556,7 @@ def _compute_metrics( diff[np.abs(diff) < 1e-10] = 0.0 abs_error = np.abs(diff) - t_statistic, p_value = cast(tuple[float, float], ttest_1samp(diff, 0)) + t_statistic, p_value = cast("tuple[float, float]", ttest_1samp(diff, 0)) num_responses = all_human_scores.shape[1] num_human_raters = all_human_scores.shape[0] @@ -624,7 +627,7 @@ def _validate_and_extract_data( objectives: list[str] = [] for entry in labeled_dataset.entries: - objective_entry = cast(ObjectiveHumanLabeledEntry, entry) + objective_entry = cast("ObjectiveHumanLabeledEntry", entry) for message in objective_entry.conversation: self.scorer._memory.add_message_to_memory(request=message) assistant_responses.append(message) diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index 560e53654e..81f254d9b9 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -5,14 +5,15 @@ import json from dataclasses import asdict, dataclass, field -from pathlib import Path from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union -import numpy as np - from pyrit.common.utils import verify_and_resolve_path if TYPE_CHECKING: + from pathlib import Path + + import numpy as np + from pyrit.identifiers import ComponentIdentifier from pyrit.models.harm_definition import HarmDefinition diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index 5d6010376a..ea6b0b4e4c 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -3,9 +3,8 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( @@ -14,6 +13,9 @@ ) from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +if TYPE_CHECKING: + from pyrit.identifiers import ComponentIdentifier + class QuestionAnswerScorer(TrueFalseScorer): """ diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 6e2c9948bf..44bb362748 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -3,11 +3,8 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ComponentIdentifier -from pyrit.models import MessagePiece, Score, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -15,6 +12,11 @@ ) from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +if TYPE_CHECKING: + from pyrit.identifiers import ComponentIdentifier + from pyrit.models import MessagePiece, Score, UnvalidatedScore + from pyrit.prompt_target import PromptChatTarget + class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): """ diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index 93a1dcfcc6..bf1c017dde 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -3,13 +3,10 @@ from __future__ import annotations -import pathlib -from typing import Optional +from typing import TYPE_CHECKING, Optional from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.common.utils import verify_and_resolve_path -from pyrit.models import MessagePiece, Score, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.self_ask_true_false_scorer import SelfAskTrueFalseScorer from pyrit.score.true_false.true_false_score_aggregator import ( @@ -17,6 +14,12 @@ TrueFalseScoreAggregator, ) +if TYPE_CHECKING: + import pathlib + + from pyrit.models import MessagePiece, Score, UnvalidatedScore + from pyrit.prompt_target import PromptChatTarget + class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): """ diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index fdd1f31bdf..f31e4640e6 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -297,7 +297,6 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: ValueError: If an initializer name is not found in the registry. """ from pyrit.registry import InitializerRegistry - from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer if not self._initializer_configs: return [] diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index ef49184977..89c0348979 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -182,7 +182,7 @@ def with_prompt_normalizer(self) -> "AttackBuilder": """Add a mock prompt normalizer.""" normalizer = MagicMock(spec=PromptNormalizer) normalizer.send_prompt_async = AsyncMock(return_value=None) - self.prompt_normalizer = cast(PromptNormalizer, normalizer) + self.prompt_normalizer = cast("PromptNormalizer", normalizer) return self def build(self) -> TreeOfAttacksWithPruningAttack: @@ -225,7 +225,7 @@ def _create_mock_target() -> PromptTarget: class_name="MockTarget", class_module="test_module", ) - return cast(PromptTarget, target) + return cast("PromptTarget", target) @staticmethod def _create_mock_chat() -> PromptChatTarget: @@ -236,7 +236,7 @@ def _create_mock_chat() -> PromptChatTarget: class_name="MockChatTarget", class_module="test_module", ) - return cast(PromptChatTarget, chat) + return cast("PromptChatTarget", chat) @staticmethod def _create_mock_scorer(name: str) -> TrueFalseScorer: @@ -247,7 +247,7 @@ def _create_mock_scorer(name: str) -> TrueFalseScorer: class_name=name, class_module="test_module", ) - return cast(TrueFalseScorer, scorer) + return cast("TrueFalseScorer", scorer) @staticmethod def _create_mock_aux_scorer(name: str) -> Scorer: @@ -259,7 +259,7 @@ def _create_mock_aux_scorer(name: str) -> Scorer: class_name=name, class_module="test_module", ) - return cast(Scorer, scorer) + return cast("Scorer", scorer) class TestHelpers: diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 8fadfb013d..3106409fde 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -3,8 +3,7 @@ import uuid -from collections.abc import Sequence -from typing import Optional +from typing import TYPE_CHECKING, Optional from pyrit.common.utils import to_sha256 from pyrit.identifiers import ComponentIdentifier @@ -19,6 +18,9 @@ Score, ) +if TYPE_CHECKING: + from collections.abc import Sequence + def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_categories=None, labels=None): """Helper function to create MessagePiece with optional targeted harm categories and labels.""" diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 6001c2756a..4316270259 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -4,16 +4,19 @@ import os import uuid from collections.abc import Generator, MutableSequence, Sequence +from typing import TYPE_CHECKING import pytest from pyrit.memory import AzureSQLMemory, EmbeddingDataEntry, PromptMemoryEntry -from pyrit.memory.memory_models import Base from pyrit.models import MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget from unit.mocks import get_azure_sql_memory, get_sample_conversation_entries +if TYPE_CHECKING: + from pyrit.memory.memory_models import Base + @pytest.fixture def memory_interface() -> Generator[AzureSQLMemory, None, None]: