From 3ca0f70c623782594400cbca12b97c0704f59c4c Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 23 Feb 2026 10:31:15 -0800 Subject: [PATCH 1/5] Enable ruff TCH (flake8-type-checking) rules and fix all violations Move type-only imports into TYPE_CHECKING blocks to improve runtime performance: - Move first-party imports used only in annotations (TC001) - Move third-party imports used only in annotations (TC002) - Move stdlib imports used only in annotations (TC003) - Fix string union in runtime annotation with __future__ annotations (TC010) - Ignore TC006 (quoting cast types is unnecessary for builtins) - Clean up resulting isort and unused import issues Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyproject.toml | 2 ++ .../remote/remote_dataset_loader.py | 6 ++++-- pyrit/executor/attack/core/attack_strategy.py | 11 +++++++---- .../attack/multi_turn/chunked_request.py | 6 ++++-- .../multi_turn/multi_turn_attack_strategy.py | 17 ++++++++++------- pyrit/executor/attack/multi_turn/red_teaming.py | 6 ++++-- .../attack/multi_turn/simulated_conversation.py | 11 +++++++---- .../attack/single_turn/many_shot_jailbreak.py | 2 +- .../single_turn/single_turn_attack_strategy.py | 9 ++++++--- pyrit/executor/promptgen/anecdoctor.py | 6 ++++-- .../promptgen/core/prompt_generator_strategy.py | 6 ++++-- pyrit/executor/promptgen/fuzzer/fuzzer.py | 8 +++++--- .../executor/workflow/core/workflow_strategy.py | 6 ++++-- pyrit/identifiers/scorer_identifier.py | 6 ++++-- pyrit/memory/azure_sql_memory.py | 6 ++++-- pyrit/memory/memory_interface.py | 6 ++++-- .../chat_message_normalizer.py | 6 ++++-- .../tokenizer_template_normalizer.py | 4 ++-- pyrit/models/attack_result.py | 12 +++++++----- pyrit/models/data_type_serializer.py | 2 +- pyrit/models/message.py | 6 ++++-- pyrit/models/message_piece.py | 6 ++++-- pyrit/models/seeds/seed.py | 9 ++++++--- pyrit/models/seeds/seed_attack_group.py | 6 ++++-- pyrit/models/seeds/seed_dataset.py | 12 +++++++----- pyrit/models/seeds/seed_objective.py | 6 ++++-- pyrit/models/seeds/seed_prompt.py | 7 ++++--- pyrit/prompt_converter/binary_converter.py | 8 +++++--- pyrit/prompt_converter/word_doc_converter.py | 11 +++++++---- .../hugging_face/hugging_face_chat_target.py | 4 ++-- .../openai/openai_response_target.py | 2 +- pyrit/scenario/core/dataset_configuration.py | 2 +- pyrit/scenario/core/scenario.py | 2 +- pyrit/scenario/core/scenario_strategy.py | 2 ++ pyrit/scenario/scenarios/airt/cyber.py | 6 ++++-- pyrit/scenario/scenarios/airt/scam.py | 6 ++++-- .../scenarios/foundry/red_team_agent.py | 6 ++++-- .../self_ask_general_float_scale_scorer.py | 10 ++++++---- .../float_scale/video_float_scale_scorer.py | 6 ++++-- pyrit/score/scorer.py | 2 +- .../score/scorer_evaluation/scorer_evaluator.py | 11 +++++++---- pyrit/score/scorer_evaluation/scorer_metrics.py | 7 ++++--- .../score/true_false/question_answer_scorer.py | 6 ++++-- .../self_ask_general_true_false_scorer.py | 10 ++++++---- .../self_ask_question_answer_scorer.py | 11 +++++++---- pyrit/setup/configuration_loader.py | 1 - tests/unit/memory/test_azure_sql_memory.py | 6 ++++-- 47 files changed, 195 insertions(+), 114 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7fad04abe..a91eaf806a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -251,6 +251,7 @@ select = [ "DOC", # https://docs.astral.sh/ruff/rules/#pydoclint-doc "F401", # unused-import "I", # isort + "TCH", # https://docs.astral.sh/ruff/rules/#flake8-type-checking-tch ] ignore = [ "D100", # Missing docstring in public module @@ -259,6 +260,7 @@ ignore = [ "D212", # Multi-line docstring summary should start at the first line "D301", # Use r""" if any backslashes in a docstring "DOC502", # Raised exception is not explicitly raised + "TC006", # runtime-cast-value (quoting builtin types in cast() is unnecessary) ] extend-select = [ "D204", # 1 blank line required after class docstring diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 1c9d538824..21e2ace12f 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -129,10 +129,12 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> List[Dict[st if response.status_code == 200: if file_type in FILE_TYPE_HANDLERS: if file_type == "json": - return cast(List[Dict[str, str]], FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))) + return cast( + "List[Dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) + ) else: return cast( - List[Dict[str, str]], + "List[Dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), ) else: diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index a442d81548..6584720608 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -4,14 +4,12 @@ from __future__ import annotations import dataclasses -import logging import time from abc import ABC from dataclasses import dataclass, field -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, overload +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar, overload from pyrit.common.logger import logger -from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.core import ( Strategy, @@ -28,7 +26,12 @@ ConversationReference, Message, ) -from pyrit.prompt_target import PromptTarget + +if TYPE_CHECKING: + import logging + + from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.prompt_target import PromptTarget AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext[Any]") AttackStrategyResultT = TypeVar("AttackStrategyResultT", bound="AttackResult") diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index feabb98215..57dbbf45d9 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -5,7 +5,7 @@ import textwrap from dataclasses import dataclass, field from string import Formatter -from typing import Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.exceptions import ComponentRole, execution_context @@ -28,7 +28,9 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget -from pyrit.score import TrueFalseScorer + +if TYPE_CHECKING: + from pyrit.score import TrueFalseScorer logger = logging.getLogger(__name__) diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 6de7127969..174ef124ec 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -3,11 +3,10 @@ from __future__ import annotations -import logging import uuid from abc import ABC from dataclasses import dataclass, field -from typing import Any, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -16,11 +15,15 @@ AttackStrategy, AttackStrategyResultT, ) -from pyrit.models import ( - Message, - Score, -) -from pyrit.prompt_target import PromptTarget + +if TYPE_CHECKING: + import logging + + from pyrit.models import ( + Message, + Score, + ) + from pyrit.prompt_target import PromptTarget MultiTurnAttackStrategyContextT = TypeVar("MultiTurnAttackStrategyContextT", bound="MultiTurnAttackContext[Any]") diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 33b2c75d75..7a89a9c769 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -6,7 +6,7 @@ import enum import logging from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH @@ -37,7 +37,9 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target.common.prompt_target import PromptTarget + +if TYPE_CHECKING: + from pyrit.prompt_target.common.prompt_target import PromptTarget logger = logging.getLogger(__name__) diff --git a/pyrit/executor/attack/multi_turn/simulated_conversation.py b/pyrit/executor/attack/multi_turn/simulated_conversation.py index a8144e2d01..c3e5a67859 100644 --- a/pyrit/executor/attack/multi_turn/simulated_conversation.py +++ b/pyrit/executor/attack/multi_turn/simulated_conversation.py @@ -11,8 +11,7 @@ from __future__ import annotations import logging -from pathlib import Path -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from pyrit.executor.attack.core.attack_config import ( AttackAdversarialConfig, @@ -23,8 +22,12 @@ from pyrit.memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer from pyrit.models import Message, SeedPrompt, SeedSimulatedConversation -from pyrit.prompt_target import PromptChatTarget -from pyrit.score import TrueFalseScorer + +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.prompt_target import PromptChatTarget + from pyrit.score import TrueFalseScorer logger = logging.getLogger(__name__) diff --git a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py index 4aae80cf90..500dacb898 100644 --- a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py +++ b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py @@ -33,7 +33,7 @@ def fetch_many_shot_jailbreaking_dataset() -> list[dict[str, str]]: source = "https://raw.githubusercontent.com/KutalVolkan/many-shot-jailbreaking-dataset/5eac855/examples.json" response = requests.get(source) response.raise_for_status() - return cast(list[dict[str, str]], response.json()) + return cast("list[dict[str, str]]", response.json()) class ManyShotJailbreakAttack(PromptSendingAttack): diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index 7a8ff9d399..715d2e79ff 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -3,17 +3,20 @@ from __future__ import annotations -import logging import uuid from abc import ABC from dataclasses import dataclass, field -from typing import Any, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Type, Union from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy from pyrit.models import AttackResult -from pyrit.prompt_target import PromptTarget + +if TYPE_CHECKING: + import logging + + from pyrit.prompt_target import PromptTarget @dataclass diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 82ecb25e5f..949a76f425 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -7,7 +7,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, overload +from typing import TYPE_CHECKING, Any, Dict, List, Optional, overload import yaml @@ -24,7 +24,9 @@ Message, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target import PromptChatTarget + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptChatTarget logger = logging.getLogger(__name__) diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index 39eb0b6b31..e0af009924 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -3,10 +3,9 @@ from __future__ import annotations -import logging from abc import ABC from dataclasses import dataclass -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -17,6 +16,9 @@ ) from pyrit.models import StrategyResult +if TYPE_CHECKING: + import logging + PromptGeneratorStrategyContextT = TypeVar("PromptGeneratorStrategyContextT", bound="PromptGeneratorStrategyContext") PromptGeneratorStrategyResultT = TypeVar("PromptGeneratorStrategyResultT", bound="PromptGeneratorStrategyResult") diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 93360a16dd..0255f0f484 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -8,7 +8,7 @@ import textwrap import uuid from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload import numpy as np from colorama import Fore, Style @@ -23,7 +23,6 @@ PromptGeneratorStrategyContext, PromptGeneratorStrategyResult, ) -from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory import CentralMemory from pyrit.models import ( @@ -33,9 +32,12 @@ SeedPrompt, ) from pyrit.prompt_normalizer import NormalizerRequest, PromptNormalizer -from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.score import FloatScaleThresholdScorer, Scorer, SelfAskScaleScorer +if TYPE_CHECKING: + from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter + from pyrit.prompt_target import PromptChatTarget, PromptTarget + logger = logging.getLogger(__name__) diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index 1b6fab4d14..0f16107a29 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -3,10 +3,9 @@ from __future__ import annotations -import logging from abc import ABC from dataclasses import dataclass -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -18,6 +17,9 @@ ) from pyrit.models import StrategyResult +if TYPE_CHECKING: + import logging + WorkflowContextT = TypeVar("WorkflowContextT", bound="WorkflowContext") WorkflowResultT = TypeVar("WorkflowResultT", bound="WorkflowResult") diff --git a/pyrit/identifiers/scorer_identifier.py b/pyrit/identifiers/scorer_identifier.py index 8dac5fc676..0f7b3b7e10 100644 --- a/pyrit/identifiers/scorer_identifier.py +++ b/pyrit/identifiers/scorer_identifier.py @@ -4,10 +4,12 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type from pyrit.identifiers.identifier import _MAX_STORAGE_LENGTH, Identifier -from pyrit.models.score import ScoreType + +if TYPE_CHECKING: + from pyrit.models.score import ScoreType @dataclass(frozen=True) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 52c8cbf6d8..bb0b2553a8 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -5,9 +5,8 @@ import struct from contextlib import closing from datetime import datetime, timedelta, timezone -from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Any, MutableSequence, Optional, Sequence, TypeVar, Union -from azure.core.credentials import AccessToken from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError @@ -30,6 +29,9 @@ MessagePiece, ) +if TYPE_CHECKING: + from azure.core.credentials import AccessToken + logger = logging.getLogger(__name__) Model = TypeVar("Model") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 8a09b63b80..2fe80c817b 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -10,12 +10,11 @@ from contextlib import closing from datetime import datetime from pathlib import Path -from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Any, MutableSequence, Optional, Sequence, TypeVar, Union from sqlalchemy import MetaData, and_, or_ from sqlalchemy.engine.base import Engine from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.sql.elements import ColumnElement from pyrit.common.path import DB_DATA_PATH from pyrit.memory.memory_embedding import ( @@ -49,6 +48,9 @@ sort_message_pieces, ) +if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + logger = logging.getLogger(__name__) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 1f4e95d9cf..529012c43d 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,7 +4,7 @@ import base64 import json import os -from typing import Any, List, Union +from typing import TYPE_CHECKING, Any, List, Union from pyrit.common import convert_local_image_to_data_url from pyrit.message_normalizer.message_normalizer import ( @@ -14,9 +14,11 @@ apply_system_message_behavior, ) from pyrit.models import ChatMessage, DataTypeSerializer, Message -from pyrit.models.literals import ChatMessageRole from pyrit.models.message_piece import MessagePiece +if TYPE_CHECKING: + from pyrit.models.literals import ChatMessageRole + # Supported audio formats for OpenAI input_audio # https://platform.openai.com/docs/guides/audio SUPPORTED_AUDIO_FORMATS = {".wav": "wav", ".mp3": "mp3"} diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index b62e3b5234..1d12bcf51e 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -122,10 +122,10 @@ def _load_tokenizer(model_name: str, token: Optional[str]) -> "PreTrainedTokeniz Returns: The loaded tokenizer. """ - from transformers import AutoTokenizer, PreTrainedTokenizerBase + from transformers import AutoTokenizer return cast( - PreTrainedTokenizerBase, + "PreTrainedTokenizerBase", AutoTokenizer.from_pretrained(model_name, token=token or None), # type: ignore[no-untyped-call, unused-ignore] ) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 3518ce4626..e605f610d0 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -5,14 +5,16 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar -from pyrit.identifiers import AttackIdentifier -from pyrit.models.conversation_reference import ConversationReference, ConversationType -from pyrit.models.message_piece import MessagePiece -from pyrit.models.score import Score from pyrit.models.strategy_result import StrategyResult +if TYPE_CHECKING: + from pyrit.identifiers import AttackIdentifier + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.message_piece import MessagePiece + from pyrit.models.score import Score + AttackResultT = TypeVar("AttackResultT", bound="AttackResult") diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 4ace4e2bba..a683052c26 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -17,11 +17,11 @@ import aiofiles from pyrit.common.path import DB_DATA_PATH -from pyrit.models.literals import PromptDataType from pyrit.models.storage_io import DiskStorageIO, StorageIO if TYPE_CHECKING: from pyrit.memory import MemoryInterface + from pyrit.models.literals import PromptDataType # Define allowed categories for validation AllowedCategories = Literal["seed-prompt-entries", "prompt-memory-entries"] diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 509d70cb29..c20e996688 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -7,12 +7,14 @@ import uuid import warnings from datetime import datetime -from typing import Dict, MutableSequence, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, MutableSequence, Optional, Sequence, Union from pyrit.common.utils import combine_dict -from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError from pyrit.models.message_piece import MessagePiece +if TYPE_CHECKING: + from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError + class Message: """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index f07b045318..885140eb24 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,12 +5,14 @@ import uuid from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, get_args from uuid import uuid4 from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError -from pyrit.models.score import Score + +if TYPE_CHECKING: + from pyrit.models.score import Score Originator = Literal["attack", "converter", "undefined", "scorer"] diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index f2a6bc4697..dbfeb012d0 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -15,13 +15,16 @@ import uuid from dataclasses import dataclass, field from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Iterator, Optional, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Sequence, TypeVar, Union from jinja2 import BaseLoader, Environment, StrictUndefined, Template, Undefined from pyrit.common.yaml_loadable import YamlLoadable -from pyrit.models.literals import PromptDataType + +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.models.literals import PromptDataType logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index f16d17aa1c..0514414d1e 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -9,12 +9,14 @@ from __future__ import annotations -from typing import Any, Dict, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Sequence, Union -from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective +if TYPE_CHECKING: + from pyrit.models.seeds.seed import Seed + class SeedAttackGroup(SeedGroup): """ diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index ec20f3b208..f3d144edd6 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -13,20 +13,22 @@ import warnings from collections import defaultdict from datetime import datetime -from typing import Any, Dict, Optional, Sequence, Union - -from pydantic.types import PositiveInt +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union from pyrit.common import utils from pyrit.common.yaml_loadable import YamlLoadable -from pyrit.models.literals import PromptDataType, SeedType -from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_attack_group import SeedAttackGroup from pyrit.models.seeds.seed_group import SeedGroup from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt from pyrit.models.seeds.seed_simulated_conversation import SeedSimulatedConversation +if TYPE_CHECKING: + from pydantic.types import PositiveInt + + from pyrit.models.literals import PromptDataType, SeedType + from pyrit.models.seeds.seed import Seed + logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_objective.py b/pyrit/models/seeds/seed_objective.py index fe4ae90c6b..6faf36f272 100644 --- a/pyrit/models/seeds/seed_objective.py +++ b/pyrit/models/seeds/seed_objective.py @@ -9,12 +9,14 @@ import logging from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from pyrit.common.path import PATHS_DICT from pyrit.models.seeds.seed import Seed +if TYPE_CHECKING: + from pathlib import Path + logger = logging.getLogger(__name__) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 8d478f7faa..d51365e780 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -9,20 +9,21 @@ import logging import os -import uuid from dataclasses import dataclass, field -from pathlib import Path from typing import TYPE_CHECKING, Optional, Sequence, Union from tinytag import TinyTag from pyrit.common.path import PATHS_DICT from pyrit.models import DataTypeSerializer -from pyrit.models.literals import ChatMessageRole, PromptDataType from pyrit.models.seeds.seed import Seed if TYPE_CHECKING: + import uuid + from pathlib import Path + from pyrit.models import Message + from pyrit.models.literals import ChatMessageRole, PromptDataType logger = logging.getLogger(__name__) diff --git a/pyrit/prompt_converter/binary_converter.py b/pyrit/prompt_converter/binary_converter.py index 245da36f71..79c3f95ba9 100644 --- a/pyrit/prompt_converter/binary_converter.py +++ b/pyrit/prompt_converter/binary_converter.py @@ -4,12 +4,14 @@ from __future__ import annotations from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ConverterIdentifier -from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy from pyrit.prompt_converter.word_level_converter import WordLevelConverter +if TYPE_CHECKING: + from pyrit.identifiers import ConverterIdentifier + from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy + class BinaryConverter(WordLevelConverter): """ diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 63725afc7f..c7212ae9ab 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -7,17 +7,20 @@ import hashlib from dataclasses import dataclass from io import BytesIO -from pathlib import Path -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from docx import Document from pyrit.common.logger import logger -from pyrit.identifiers import ConverterIdentifier from pyrit.models import PromptDataType, SeedPrompt, data_serializer_factory -from pyrit.models.data_type_serializer import DataTypeSerializer from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.identifiers import ConverterIdentifier + from pyrit.models.data_type_serializer import DataTypeSerializer + @dataclass class _WordDocInjectionConfig: diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index f320248e2c..d24993a0c3 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -331,7 +331,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Decode the assistant's response from the generated token IDs assistant_response = cast( - str, + "str", self.tokenizer.decode(generated_tokens, skip_special_tokens=self.skip_special_tokens), ).strip() @@ -372,7 +372,7 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: # Apply the chat template to format and tokenize the messages tokenized_chat = cast( - BatchEncoding, + "BatchEncoding", self.tokenizer.apply_chat_template( messages, tokenize=True, diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index bc709fa4b7..3fa526f1fe 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -698,7 +698,7 @@ def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any continue if section.get("type") == "function_call": # Do NOT skip function_call even if status == "completed" — we still need to emit the output. - return cast(dict[str, Any], section) + return cast("dict[str, Any]", section) return None async def _execute_call_section(self, tool_call_section: dict[str, Any]) -> dict[str, Any]: diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index f585d8d518..60ebc72d2b 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -15,9 +15,9 @@ from pyrit.memory import CentralMemory from pyrit.models import SeedAttackGroup, SeedGroup -from pyrit.models.seeds.seed import Seed if TYPE_CHECKING: + from pyrit.models.seeds.seed import Seed from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy # Key used when seed_groups are provided directly (not from a named dataset) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 7d9b3c3a3f..751a4b503d 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -19,7 +19,6 @@ from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack -from pyrit.identifiers import TargetIdentifier from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry from pyrit.models import AttackResult @@ -35,6 +34,7 @@ if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig + from pyrit.identifiers import TargetIdentifier from pyrit.models import SeedAttackGroup logger = logging.getLogger(__name__) diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index d1f1cdceb6..8be9b19e51 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -11,6 +11,8 @@ It also provides ScenarioCompositeStrategy for representing composed attack strategies. """ +from __future__ import annotations + from enum import Enum from typing import List, Sequence, Set, TypeVar diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index ea8ddf41df..e7a9584e25 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -12,7 +12,6 @@ AttackAdversarialConfig, AttackScoringConfig, ) -from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.executor.attack.multi_turn.red_teaming import RedTeamingAttack from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.models import SeedAttackGroup, SeedObjective @@ -33,6 +32,9 @@ TrueFalseScorer, ) +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_strategy import AttackStrategy + logger = logging.getLogger(__name__) diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index 2afc2e4b29..c05f39d7d9 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -4,7 +4,7 @@ import logging import os from pathlib import Path -from typing import Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional from pyrit.common import apply_defaults from pyrit.common.path import ( @@ -21,7 +21,6 @@ AttackAdversarialConfig, AttackScoringConfig, ) -from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.models import SeedAttackGroup, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget from pyrit.scenario.core.atomic_attack import AtomicAttack @@ -40,6 +39,9 @@ TrueFalseScorer, ) +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_strategy import AttackStrategy + logger = logging.getLogger(__name__) PERSUASION_DECEPTION_PATH = Path(EXECUTOR_RED_TEAM_PATH, "persuasion_deception").resolve() diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index b85242bcd1..6bd7326ce8 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -12,7 +12,7 @@ import logging import os from inspect import signature -from typing import Any, List, Optional, Sequence, Type, TypeVar +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Type, TypeVar from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message @@ -28,7 +28,6 @@ AttackConverterConfig, AttackScoringConfig, ) -from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.models import SeedAttackGroup, SeedObjective from pyrit.prompt_converter import ( AnsiAttackConverter, @@ -77,6 +76,9 @@ TrueFalseScoreAggregator, ) +if TYPE_CHECKING: + from pyrit.executor.attack.core.attack_strategy import AttackStrategy + AttackStrategyT = TypeVar("AttackStrategyT", bound="AttackStrategy[Any, Any]") logger = logging.getLogger(__name__) diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index 50f566effa..6d31341093 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -3,14 +3,16 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ScorerIdentifier -from pyrit.models import MessagePiece, Score, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator +if TYPE_CHECKING: + from pyrit.identifiers import ScorerIdentifier + from pyrit.models import MessagePiece, Score, UnvalidatedScore + from pyrit.prompt_target import PromptChatTarget + class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): """ diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index fc4ff9454f..d858e53434 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from pyrit.identifiers import ScorerIdentifier from pyrit.models import MessagePiece, Score @@ -10,10 +10,12 @@ FloatScaleScorerByCategory, ) from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer -from pyrit.score.score_aggregator_result import ScoreAggregatorResult from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.video_scorer import _BaseVideoScorer +if TYPE_CHECKING: + from pyrit.score.score_aggregator_result import ScoreAggregatorResult + class VideoFloatScaleScorer( FloatScaleScorer, diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 6765d907e1..27e57ee095 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -39,10 +39,10 @@ ) from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.prompt_target.batch_helper import batch_task_async -from pyrit.score.scorer_prompt_validator import ScorerPromptValidator if TYPE_CHECKING: from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics + from pyrit.score.scorer_prompt_validator import ScorerPromptValidator logger = logging.getLogger(__name__) diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index 7c7a2e9b80..2b00ec673f 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -7,15 +7,12 @@ import logging import time from dataclasses import dataclass -from pathlib import Path -from typing import List, Optional, Tuple, cast +from typing import TYPE_CHECKING, List, Optional, Tuple, cast import numpy as np from scipy.stats import ttest_1samp from pyrit.common.path import SCORER_EVALS_PATH -from pyrit.models import Message -from pyrit.score import Scorer from pyrit.score.scorer_evaluation.human_labeled_dataset import ( HarmHumanLabeledEntry, HumanLabeledDataset, @@ -38,6 +35,12 @@ ) from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +if TYPE_CHECKING: + from pathlib import Path + + from pyrit.models import Message + from pyrit.score import Scorer + logger = logging.getLogger(__name__) # Standard column names for evaluation datasets diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index b88eee2f96..422e5804df 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -5,14 +5,15 @@ import json from dataclasses import asdict, dataclass, field -from pathlib import Path from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, Union -import numpy as np - from pyrit.common.utils import verify_and_resolve_path if TYPE_CHECKING: + from pathlib import Path + + import numpy as np + from pyrit.identifiers import ScorerIdentifier from pyrit.models.harm_definition import HarmDefinition diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index 1717c8c44f..96ff3cbbb5 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -3,9 +3,8 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ScorerIdentifier from pyrit.models import MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( @@ -14,6 +13,9 @@ ) from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +if TYPE_CHECKING: + from pyrit.identifiers import ScorerIdentifier + class QuestionAnswerScorer(TrueFalseScorer): """ diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 2de626e2a1..ae114aae22 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -3,11 +3,8 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import ScorerIdentifier -from pyrit.models import MessagePiece, Score, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -15,6 +12,11 @@ ) from pyrit.score.true_false.true_false_scorer import TrueFalseScorer +if TYPE_CHECKING: + from pyrit.identifiers import ScorerIdentifier + from pyrit.models import MessagePiece, Score, UnvalidatedScore + from pyrit.prompt_target import PromptChatTarget + class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): """ diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index 93a1dcfcc6..bf1c017dde 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -3,13 +3,10 @@ from __future__ import annotations -import pathlib -from typing import Optional +from typing import TYPE_CHECKING, Optional from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.common.utils import verify_and_resolve_path -from pyrit.models import MessagePiece, Score, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.self_ask_true_false_scorer import SelfAskTrueFalseScorer from pyrit.score.true_false.true_false_score_aggregator import ( @@ -17,6 +14,12 @@ TrueFalseScoreAggregator, ) +if TYPE_CHECKING: + import pathlib + + from pyrit.models import MessagePiece, Score, UnvalidatedScore + from pyrit.prompt_target import PromptChatTarget + class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): """ diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 999cf77bd6..1b3365239a 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -296,7 +296,6 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: ValueError: If an initializer name is not found in the registry. """ from pyrit.registry import InitializerRegistry - from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer if not self._initializer_configs: return [] diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index a9e4db4f9b..f73faaa178 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -3,17 +3,19 @@ import os import uuid -from typing import Generator, MutableSequence, Sequence +from typing import TYPE_CHECKING, Generator, MutableSequence, Sequence import pytest from pyrit.memory import AzureSQLMemory, EmbeddingDataEntry, PromptMemoryEntry -from pyrit.memory.memory_models import Base from pyrit.models import MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget from unit.mocks import get_azure_sql_memory, get_sample_conversation_entries +if TYPE_CHECKING: + from pyrit.memory.memory_models import Base + @pytest.fixture def memory_interface() -> Generator[AzureSQLMemory, None, None]: From ff8fae13adf7d5a1b12124bc3e94856b86fa7944 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 25 Feb 2026 05:31:06 -0800 Subject: [PATCH 2/5] Merge origin/main, fix TC001 in seed_attack_technique_group.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/seed_attack_technique_group.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyrit/models/seeds/seed_attack_technique_group.py b/pyrit/models/seeds/seed_attack_technique_group.py index ec5db2b822..02babf8678 100644 --- a/pyrit/models/seeds/seed_attack_technique_group.py +++ b/pyrit/models/seeds/seed_attack_technique_group.py @@ -11,9 +11,11 @@ from __future__ import annotations -from typing import Any, Dict, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Sequence, Union + +if TYPE_CHECKING: + from pyrit.models.seeds.seed import Seed -from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_group import SeedGroup From 160d4bf2e1d41321aef52c6a743242e01bd489b0 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 25 Feb 2026 05:42:42 -0800 Subject: [PATCH 3/5] FIX ruff format for 2 files Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py | 4 +++- pyrit/prompt_target/openai/openai_video_target.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 5bb24e1951..89a6150191 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -129,7 +129,9 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> List[Dict[st if response.status_code == 200: if file_type in FILE_TYPE_HANDLERS: if file_type == "json": - return cast("List[Dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))) + return cast( + "List[Dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) + ) return cast( "List[Dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 700d338aea..276bbfc2c7 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -93,7 +93,9 @@ def __init__( """ super().__init__(**kwargs) - self._n_seconds: VideoSeconds = cast("VideoSeconds", str(n_seconds)) if isinstance(n_seconds, int) else n_seconds + self._n_seconds: VideoSeconds = ( + cast("VideoSeconds", str(n_seconds)) if isinstance(n_seconds, int) else n_seconds + ) self._validate_duration() self._size: VideoSize = self._validate_resolution(resolution_dimensions=resolution_dimensions) From 10dd15a638a85706a7ad2aa595658a662a9e6b9d Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 25 Feb 2026 06:20:09 -0800 Subject: [PATCH 4/5] FIX Mock tokenizer in unit test to avoid HuggingFace network call The chatml_tokenizer_normalizer fixture was calling AutoTokenizer.from_pretrained() which requires network access to HuggingFace. Replaced with a mock that simulates ChatML template formatting, making the test fully offline. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../test_chat_normalizer_tokenizer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py b/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py index 0086f162ea..81122fcc57 100644 --- a/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py +++ b/tests/unit/message_normalizer/test_chat_normalizer_tokenizer.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest -from transformers import AutoTokenizer from pyrit.message_normalizer import TokenizerTemplateNormalizer from pyrit.models import Message, MessagePiece @@ -116,8 +115,18 @@ class TestNormalizeStringAsync: @pytest.fixture def chatml_tokenizer_normalizer(self): - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - return TokenizerTemplateNormalizer(tokenizer=tokenizer) + def _apply_chatml_template(messages, tokenize=False, add_generation_prompt=False): + """Simulate ChatML template formatting.""" + result = "" + for msg in messages: + result += f"<|{msg['role']}|>\n{msg['content']}\n" + if add_generation_prompt: + result += "<|assistant|>\n" + return result + + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.side_effect = _apply_chatml_template + return TokenizerTemplateNormalizer(tokenizer=mock_tokenizer) @pytest.mark.asyncio async def test_normalize_chatml(self, chatml_tokenizer_normalizer: TokenizerTemplateNormalizer): From 4a06e4eccdccae0581887149fc99ddd12787e9d3 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 26 Feb 2026 11:03:52 -0800 Subject: [PATCH 5/5] Address review: move logging out of TYPE_CHECKING, move PromptTarget to TYPE_CHECKING - scorer.py: Move PromptChatTarget/PromptTarget to TYPE_CHECKING block (only used in annotations, file has 'from __future__ import annotations') - workflow_strategy.py + 4 other files: Move 'import logging' back to runtime with noqa:TC003 (stdlib module, no import cost concern) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/attack/core/attack_strategy.py | 3 +-- .../attack/multi_turn/multi_turn_attack_strategy.py | 3 +-- .../attack/single_turn/single_turn_attack_strategy.py | 3 +-- pyrit/executor/promptgen/core/prompt_generator_strategy.py | 6 ++---- pyrit/executor/workflow/core/workflow_strategy.py | 6 ++---- pyrit/score/scorer.py | 2 +- 6 files changed, 8 insertions(+), 15 deletions(-) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 99596ce6a5..dac86d10aa 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -4,6 +4,7 @@ from __future__ import annotations import dataclasses +import logging # noqa: TC003 import time from abc import ABC from dataclasses import dataclass, field @@ -28,8 +29,6 @@ ) if TYPE_CHECKING: - import logging - from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.prompt_target import PromptTarget diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 58c9b6bbe9..5620c90032 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -3,6 +3,7 @@ from __future__ import annotations +import logging # noqa: TC003 import uuid from abc import ABC from dataclasses import dataclass, field @@ -17,8 +18,6 @@ ) if TYPE_CHECKING: - import logging - from pyrit.models import ( Message, Score, diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index a4ae81f476..508f83db1c 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -3,6 +3,7 @@ from __future__ import annotations +import logging # noqa: TC003 import uuid from abc import ABC from dataclasses import dataclass, field @@ -14,8 +15,6 @@ from pyrit.models import AttackResult if TYPE_CHECKING: - import logging - from pyrit.prompt_target import PromptTarget diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index e0af009924..29068611ba 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -3,9 +3,10 @@ from __future__ import annotations +import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import Optional, TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -16,9 +17,6 @@ ) from pyrit.models import StrategyResult -if TYPE_CHECKING: - import logging - PromptGeneratorStrategyContextT = TypeVar("PromptGeneratorStrategyContextT", bound="PromptGeneratorStrategyContext") PromptGeneratorStrategyResultT = TypeVar("PromptGeneratorStrategyResultT", bound="PromptGeneratorStrategyResult") diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index ab937768e4..80dc3a298d 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -3,9 +3,10 @@ from __future__ import annotations +import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import Optional, TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -17,9 +18,6 @@ ) from pyrit.models import StrategyResult -if TYPE_CHECKING: - import logging - WorkflowContextT = TypeVar("WorkflowContextT", bound="WorkflowContext") WorkflowResultT = TypeVar("WorkflowResultT", bound="WorkflowResult") diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 08d3c6c117..0680d243d4 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -34,12 +34,12 @@ ScoreType, UnvalidatedScore, ) -from pyrit.prompt_target import PromptChatTarget, PromptTarget # noqa: TC001 from pyrit.prompt_target.batch_helper import batch_task_async if TYPE_CHECKING: from collections.abc import Sequence + from pyrit.prompt_target import PromptChatTarget, PromptTarget from pyrit.score.scorer_evaluation.metrics_type import RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_evaluator import ( ScorerEvalDatasetFiles,