From c029636d110e885bc2e1333c408273f70ab3a450 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 22 May 2026 18:44:19 -0700 Subject: [PATCH] Simplifying scenario techniques --- .../instructions/scenarios.instructions.md | 48 ++ doc/scanner/airt.ipynb | 12 +- doc/scanner/airt.py | 12 +- pyrit/cli/pyrit_scan.py | 2 +- pyrit/registry/__init__.py | 2 - pyrit/registry/object_registries/__init__.py | 2 - .../attack_technique_registry.py | 209 ++------ pyrit/scenario/core/__init__.py | 6 - .../scenario/core/attack_technique_factory.py | 267 +++++++++-- pyrit/scenario/core/scenario.py | 25 +- pyrit/scenario/core/scenario_techniques.py | 196 -------- pyrit/scenario/scenarios/airt/cyber.py | 19 +- pyrit/scenario/scenarios/airt/leakage.py | 38 +- .../scenario/scenarios/airt/rapid_response.py | 13 +- .../scenarios/benchmark/adversarial.py | 47 +- pyrit/setup/initializers/__init__.py | 2 +- .../setup/initializers/components/__init__.py | 2 +- .../components/scenario_techniques.py | 130 +++++ .../initializers/components/scenarios.py | 140 ------ tests/unit/backend/test_scenario_service.py | 12 +- .../test_attack_technique_registry.py | 172 ++++--- tests/unit/scenario/test_adversarial.py | 94 ++-- .../scenario/test_attack_technique_factory.py | 80 ++-- tests/unit/scenario/test_cyber.py | 67 +-- tests/unit/scenario/test_leakage_scenario.py | 25 +- tests/unit/scenario/test_rapid_response.py | 450 +++++------------- .../test_scenario_strategy_invariants.py | 13 +- ...> test_scenario_techniques_initializer.py} | 127 ++--- 28 files changed, 1011 insertions(+), 1201 deletions(-) delete mode 100644 pyrit/scenario/core/scenario_techniques.py create mode 100644 pyrit/setup/initializers/components/scenario_techniques.py delete mode 100644 pyrit/setup/initializers/components/scenarios.py rename tests/unit/setup/{test_scenarios_initializer.py => test_scenario_techniques_initializer.py} (70%) diff --git a/.github/instructions/scenarios.instructions.md b/.github/instructions/scenarios.instructions.md index 867375f8ab..20eae682c4 100644 --- a/.github/instructions/scenarios.instructions.md +++ b/.github/instructions/scenarios.instructions.md @@ -145,11 +145,59 @@ get atomic-attack construction **for free** — no override needed. The default implementation: 1. Calls `self._get_attack_technique_factories()` to get name→factory mapping + (defaults to reading every `AttackTechniqueFactory` registered in the + `AttackTechniqueRegistry` singleton) 2. Iterates over every (technique × dataset) pair from `self._dataset_config` 3. Calls `factory.create()` with `objective_target` and conditional scorer override 4. Uses `self._build_display_group()` for user-facing grouping 5. Builds `AtomicAttack` with unique `atomic_attack_name` = `"{technique}_{dataset}"` +### AttackTechniqueFactory + +Techniques are described by `AttackTechniqueFactory` instances rather than a separate spec +dataclass. The canonical catalog lives in +`pyrit.setup.initializers.components.scenario_techniques` (`build_scenario_technique_factories()`) +and is loaded into the registry by `ScenarioTechniqueInitializer`. + +```python +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + +AttackTechniqueFactory( + name="prompt_sending", # REQUIRED — must match the strategy enum value + attack_class=PromptSendingAttack, + strategy_tags=["core", "single_turn", "default"], + attack_kwargs={"max_turns": 5}, + adversarial_chat=None, # mutually exclusive with adversarial_config + adversarial_config=None, + seed_technique=None, + uses_adversarial=None, # None = auto-derive from attack signature/seeds + scorer_override_policy=ScorerOverridePolicy.WARN, +) +``` + +Key points: +- `name` is required and must match the strategy enum value the scenario looks up. +- `strategy_tags` on the factory drives `TagQuery` filters used by + `AttackTechniqueRegistry.build_strategy_class_from_factories(...)`. This is **distinct** + from the per-entry `tags` argument passed to `registry.register_technique(...)`. +- `uses_adversarial` is auto-derived from the attack class signature (presence of + `attack_adversarial_config`) and seed shape; pass `False` explicitly to opt out, or + `True` to force opt-in. +- `kwargs` are validated against the attack class constructor signature at + factory-construction time, so typos fail loudly and early. + +### Registering factories + +```python +registry = AttackTechniqueRegistry.get_registry_singleton() +registry.register_from_factories(build_scenario_technique_factories()) +``` + +`register_from_factories` reads `factory.strategy_tags` to populate the per-entry tags used +by the registry. Tests that exercise scenarios should reset both `AttackTechniqueRegistry` +and `TargetRegistry` and re-register a mock `adversarial_chat` so the catalog builder +resolves without falling back to `OpenAIChatTarget`. + ### Customization hooks (no need to override `_get_atomic_attacks_async`): - **`_get_attack_technique_factories()`** — override to add/remove/replace factories - **`_build_display_group()`** — override to change grouping (default: by technique) diff --git a/doc/scanner/airt.ipynb b/doc/scanner/airt.ipynb index 0d7d3fb494..ba62a19214 100644 --- a/doc/scanner/airt.ipynb +++ b/doc/scanner/airt.ipynb @@ -74,12 +74,12 @@ "pyrit_scan airt.rapid_response \\\n", " --initializers target load_default_datasets \\\n", " --target openai_chat \\\n", - " --strategies prompt_sending \\\n", + " --strategies role_play \\\n", " --dataset-names airt_hate \\ \n", " --max-dataset-size 1\n", "```\n", "\n", - "**Available strategies:** ALL, DEFAULT, SINGLE_TURN, MULTI_TURN, prompt_sending, role_play, many_shot, tap" + "**Available strategies:** ALL, DEFAULT, SINGLE_TURN, MULTI_TURN, role_play, many_shot, tap" ] }, { @@ -111,7 +111,7 @@ "scenario = RapidResponse()\n", "await scenario.initialize_async( # type: ignore\n", " objective_target=objective_target,\n", - " scenario_strategies=[RapidResponseStrategy.prompt_sending],\n", + " scenario_strategies=[RapidResponseStrategy.role_play],\n", " dataset_config=dataset_config,\n", ")\n", "\n", @@ -369,11 +369,11 @@ "pyrit_scan airt.cyber \\\n", " --initializers target load_default_datasets \\\n", " --target openai_chat \\\n", - " --strategies single_turn \\\n", + " --strategies multi_turn \\\n", " --max-dataset-size 1\n", "```\n", "\n", - "**Available strategies:** ALL, SINGLE_TURN, MULTI_TURN" + "**Available strategies:** ALL, MULTI_TURN, red_teaming" ] }, { @@ -405,7 +405,7 @@ "scenario = Cyber()\n", "await scenario.initialize_async( # type: ignore\n", " objective_target=objective_target,\n", - " scenario_strategies=[CyberStrategy.SINGLE_TURN],\n", + " scenario_strategies=[CyberStrategy.MULTI_TURN],\n", " dataset_config=dataset_config,\n", ")\n", "\n", diff --git a/doc/scanner/airt.py b/doc/scanner/airt.py index b3aa07fc24..f31468972d 100644 --- a/doc/scanner/airt.py +++ b/doc/scanner/airt.py @@ -42,12 +42,12 @@ # pyrit_scan airt.rapid_response \ # --initializers target load_default_datasets \ # --target openai_chat \ -# --strategies prompt_sending \ +# --strategies role_play \ # --dataset-names airt_hate \ # --max-dataset-size 1 # ``` # -# **Available strategies:** ALL, DEFAULT, SINGLE_TURN, MULTI_TURN, prompt_sending, role_play, many_shot, tap +# **Available strategies:** ALL, DEFAULT, SINGLE_TURN, MULTI_TURN, role_play, many_shot, tap # %% from pyrit.scenario.scenarios.airt import RapidResponse, RapidResponseStrategy @@ -57,7 +57,7 @@ scenario = RapidResponse() await scenario.initialize_async( # type: ignore objective_target=objective_target, - scenario_strategies=[RapidResponseStrategy.prompt_sending], + scenario_strategies=[RapidResponseStrategy.role_play], dataset_config=dataset_config, ) @@ -125,11 +125,11 @@ # pyrit_scan airt.cyber \ # --initializers target load_default_datasets \ # --target openai_chat \ -# --strategies single_turn \ +# --strategies multi_turn \ # --max-dataset-size 1 # ``` # -# **Available strategies:** ALL, SINGLE_TURN, MULTI_TURN +# **Available strategies:** ALL, MULTI_TURN, red_teaming # %% from pyrit.scenario.scenarios.airt import Cyber, CyberStrategy @@ -139,7 +139,7 @@ scenario = Cyber() await scenario.initialize_async( # type: ignore objective_target=objective_target, - scenario_strategies=[CyberStrategy.SINGLE_TURN], + scenario_strategies=[CyberStrategy.MULTI_TURN], dataset_config=dataset_config, ) diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 70ec9c2f4d..791634d71e 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -49,7 +49,7 @@ # Run rapid response with specific datasets and concurrency pyrit_scan airt.rapid_response --target openai_chat - --strategies prompt_sending --dataset-names airt_hate + --strategies role_play --dataset-names airt_hate --max-dataset-size 5 --max-concurrency 4 # Run multi-turn red team agent with labels for tracking diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 16527d00c3..cd94382b93 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -20,7 +20,6 @@ ) from pyrit.registry.object_registries import ( AttackTechniqueRegistry, - AttackTechniqueSpec, BaseInstanceRegistry, ConverterRegistry, RegistryEntry, @@ -49,6 +48,5 @@ "ScenarioRegistry", "ScorerRegistry", "TargetRegistry", - "AttackTechniqueSpec", "TagQuery", ] diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py index 7d5c82c14a..0a43a5af2f 100644 --- a/pyrit/registry/object_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -13,7 +13,6 @@ from pyrit.registry.object_registries.attack_technique_registry import ( AttackTechniqueRegistry, - AttackTechniqueSpec, ) from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, @@ -42,5 +41,4 @@ "ConverterRegistry", "ScorerRegistry", "TargetRegistry", - "AttackTechniqueSpec", ] diff --git a/pyrit/registry/object_registries/attack_technique_registry.py b/pyrit/registry/object_registries/attack_technique_registry.py index 55e15a02e9..e7fbbde0eb 100644 --- a/pyrit/registry/object_registries/attack_technique_registry.py +++ b/pyrit/registry/object_registries/attack_technique_registry.py @@ -4,25 +4,22 @@ """ AttackTechniqueRegistry — Singleton registry of reusable attack technique factories. -Scenarios and initializers register technique factories (capturing technique-specific -config). Scenarios retrieve factories via ``get_factories()`` and call -``factory.create()`` with the scenario's objective target and scorer. +Scenarios and initializers register self-describing ``AttackTechniqueFactory`` +instances. Scenarios retrieve factories via ``get_factories()`` and filter +them in-place (e.g. by ``factory.uses_adversarial`` or strategy tags) before +calling ``factory.create()`` with the scenario's objective target and scorer. """ from __future__ import annotations -import inspect import logging -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from pyrit.registry.object_registries.base_instance_registry import ( BaseInstanceRegistry, ) if TYPE_CHECKING: - from pyrit.models import SeedAttackTechniqueGroup - from pyrit.prompt_target import PromptTarget from pyrit.registry.tag_query import TagQuery from pyrit.scenario import AttackTechniqueFactory from pyrit.scenario.core.attack_technique_factory import ScorerOverridePolicy @@ -30,90 +27,12 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class AttackTechniqueSpec: - """ - Declarative definition of an attack technique. - - The registry converts specs into ``AttackTechniqueFactory`` instances. - A minimal spec only needs ``name`` and ``attack_class``:: - - AttackTechniqueSpec(name="prompt_sending", attack_class=PromptSendingAttack) - - Use ``extra_kwargs`` for constructor arguments specific to a particular - attack class (as opposed to common arguments like ``objective_target`` - and ``attack_scoring_config``, which the factory injects automatically):: - - AttackTechniqueSpec( - name="role_play", - attack_class=RolePlayAttack, - strategy_tags=["core", "single_turn"], - extra_kwargs={"role_play_definition_path": RolePlayPaths.MOVIE_SCRIPT.value}, - ) - - Attacks that need an adversarial chat target should set - ``adversarial_chat`` (resolved target) or ``adversarial_chat_key`` - (deferred ``TargetRegistry`` key resolved at runtime by - ``build_scenario_techniques()``). These are mutually exclusive. - The registry automatically injects an ``AttackAdversarialConfig`` when - the attack class accepts one and ``adversarial_chat`` is set. - - Args: - name: Registry name (must match the strategy enum value). - attack_class: The ``AttackStrategy`` subclass (e.g. - ``PromptSendingAttack``, ``TreeOfAttacksWithPruningAttack``). - strategy_tags: Tags controlling which ``ScenarioStrategy`` aggregates - include this technique (e.g. ``"single_turn"``, ``"multi_turn"``). - adversarial_chat: Live adversarial chat target for multi-turn attacks. - Part of technique identity. Mutually exclusive with - ``adversarial_chat_key``. - adversarial_chat_key: Deferred ``TargetRegistry`` key resolved into - ``adversarial_chat`` at runtime. Use in static spec catalogs - where the target isn't available yet. - extra_kwargs: Attack-class-specific keyword arguments forwarded to - the constructor, e.g. ``{"tree_width": 5}`` for - ``TreeOfAttacksWithPruningAttack``. Must not contain - ``attack_adversarial_config`` (use ``adversarial_chat``) or - factory-injected args (``objective_target``, - ``attack_scoring_config``). - seed_technique: Optional ``SeedAttackTechniqueGroup`` to attach to - the created ``AttackTechnique``. Seeds are merged into each - ``SeedAttackGroup`` at execution time via ``with_technique()``. - """ - - name: str - attack_class: type - strategy_tags: list[str] = field(default_factory=list) - adversarial_chat: PromptTarget | None = field(default=None) - adversarial_chat_key: str | None = None - extra_kwargs: dict[str, Any] = field(default_factory=dict) - seed_technique: SeedAttackTechniqueGroup | None = None - - @property - def tags(self) -> list[str]: - """Return strategy_tags as the Taggable interface.""" - return self.strategy_tags - - def __post_init__(self) -> None: - """ - Validate mutually exclusive fields. - - Raises: - ValueError: If both adversarial_chat and adversarial_chat_key are set. - """ - if self.adversarial_chat and self.adversarial_chat_key: - raise ValueError( - f"Technique spec '{self.name}' sets both adversarial_chat and " - f"adversarial_chat_key — these are mutually exclusive." - ) - - class AttackTechniqueRegistry(BaseInstanceRegistry["AttackTechniqueFactory"]): """ Singleton registry of reusable attack technique factories. - Scenarios and initializers register technique factories (capturing - technique-specific config). Scenarios retrieve factories via + Scenarios and initializers register self-describing + ``AttackTechniqueFactory`` instances. Scenarios retrieve factories via ``get_factories()`` and call ``factory.create()`` with the scenario's objective target and scorer. """ @@ -152,6 +71,9 @@ def get_factories(self) -> dict[str, AttackTechniqueFactory]: """ Return all registered factories as a name→factory dict. + Callers filter the result in-place using factory properties (e.g. + ``factory.uses_adversarial`` or ``factory.strategy_tags``). + Returns: dict[str, AttackTechniqueFactory]: Mapping of technique name to factory. """ @@ -163,29 +85,26 @@ def scorer_override_policy(self) -> ScorerOverridePolicy: return self._scorer_override_policy @staticmethod - def build_strategy_class_from_specs( + def build_strategy_class_from_factories( *, class_name: str, - specs: list[AttackTechniqueSpec], + factories: list[AttackTechniqueFactory], aggregate_tags: dict[str, TagQuery], ) -> type: """ - Build a ``ScenarioStrategy`` enum subclass dynamically from technique specs. + Build a ``ScenarioStrategy`` enum subclass dynamically from technique factories. Creates an enum class with: - An ``ALL`` aggregate member (always included). - Additional aggregate members from ``aggregate_tags`` keys. - - One technique member per spec, with tags from the spec. + - One technique member per factory, with tags from the factory. Each aggregate maps to a :class:`TagQuery` that determines which - technique specs belong to it. - - This reads from the **spec list** (pure data), not from the mutable - registry. This ensures deterministic output regardless of registry state. + technique factories belong to it. Args: class_name: Name for the generated enum class. - specs: Technique specifications to include as enum members. + factories: Technique factories to include as enum members. aggregate_tags: Maps aggregate member names to a :class:`TagQuery` that selects which techniques belong to the aggregate. An ``ALL`` aggregate (expanding to all techniques) is always added. @@ -204,11 +123,13 @@ def build_strategy_class_from_specs( for agg_name in aggregate_tags: members[agg_name.upper()] = (agg_name, {agg_name}) - # Technique members from specs — assign aggregate tags based on TagQuery matching - for spec in specs: - spec_tags = set(spec.strategy_tags) - matched_agg_tags = {agg_name for agg_name, query in aggregate_tags.items() if query.matches(spec_tags)} - members[spec.name] = (spec.name, spec_tags | matched_agg_tags) + # Technique members from factories — assign aggregate tags based on TagQuery matching + for factory in factories: + factory_tags = set(factory.strategy_tags) + matched_agg_tags = { + agg_name for agg_name, query in aggregate_tags.items() if query.matches(factory_tags) + } + members[factory.name] = (factory.name, factory_tags | matched_agg_tags) # Build the enum class dynamically strategy_cls = ScenarioStrategy(class_name, members) @@ -222,91 +143,27 @@ def _get_aggregate_tags(cls: type) -> set[str]: return strategy_cls # type: ignore[ty:invalid-return-type] - @staticmethod - def build_factory_from_spec( - spec: AttackTechniqueSpec, - *, - scorer_override_policy: ScorerOverridePolicy | None = None, - ) -> AttackTechniqueFactory: - """ - Build an ``AttackTechniqueFactory`` from an ``AttackTechniqueSpec``. - - The adversarial chat target (``spec.adversarial_chat``) is stored on the - factory as an ``AttackAdversarialConfig``. The factory injects it into - the attack constructor at ``create()`` time if the attack class accepts - ``attack_adversarial_config``. - - Args: - spec: The technique specification. - scorer_override_policy: Policy for incompatible scorer overrides. - Defaults to WARN when None. - - Returns: - AttackTechniqueFactory: A factory ready for registration. - - Raises: - ValueError: If ``extra_kwargs`` contains the reserved key - ``attack_adversarial_config``. - """ - from pyrit.executor.attack import AttackAdversarialConfig - from pyrit.scenario import AttackTechniqueFactory - from pyrit.scenario.core.attack_technique_factory import ScorerOverridePolicy - - scorer_override_policy = scorer_override_policy or ScorerOverridePolicy.WARN - - if "attack_adversarial_config" in spec.extra_kwargs: - raise ValueError( - f"Spec '{spec.name}': 'attack_adversarial_config' must not appear in extra_kwargs. " - "Set spec.adversarial_chat instead." - ) - - kwargs: dict[str, Any] = dict(spec.extra_kwargs) - - adversarial_config = ( - AttackAdversarialConfig(target=spec.adversarial_chat) if spec.adversarial_chat is not None else None - ) - - return AttackTechniqueFactory( - attack_class=spec.attack_class, # type: ignore[ty:invalid-argument-type] - attack_kwargs=kwargs or None, - adversarial_config=adversarial_config, - seed_technique=spec.seed_technique, - scorer_override_policy=scorer_override_policy, - ) - - @staticmethod - def _accepts_adversarial(attack_class: type) -> bool: - """ - Check if an attack class accepts ``attack_adversarial_config``. - - Returns: - bool: Whether the parameter is present in the class constructor. - """ - sig = inspect.signature(attack_class.__init__) - return "attack_adversarial_config" in sig.parameters - - def register_from_specs( + def register_from_factories( self, - specs: list[AttackTechniqueSpec], + factories: list[AttackTechniqueFactory], ) -> None: """ - Build factories from specs and register them. + Register a list of factories under their ``name``. Per-name idempotent: existing entries are not overwritten. Args: - specs: Technique specifications to register. Each spec is - self-contained: the adversarial chat target (if any) is - declared on the spec itself via ``spec.adversarial_chat``. + factories: Self-describing factories to register. Each factory's + ``name`` and ``strategy_tags`` properties are used directly. """ - for spec in specs: - if spec.name not in self: - factory = self.build_factory_from_spec(spec, scorer_override_policy=self._scorer_override_policy) - tags: dict[str, str] = dict.fromkeys(spec.strategy_tags, "") + for factory in factories: + if factory.name not in self: + tags: dict[str, str] = dict.fromkeys(factory.strategy_tags, "") self.register_technique( - name=spec.name, + name=factory.name, factory=factory, tags=tags, ) logger.debug("Technique registration complete (%d total in registry)", len(self)) + diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index fede5a1da8..7b50cef237 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -11,10 +11,6 @@ from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target, get_default_scorer_target -from pyrit.scenario.core.scenario_techniques import ( - SCENARIO_TECHNIQUES, - register_scenario_techniques, -) __all__ = [ "AtomicAttack", @@ -23,13 +19,11 @@ "BaselineAttackPolicy", "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", - "SCENARIO_TECHNIQUES", "Parameter", "Scenario", "ScenarioCompositeStrategy", "ScenarioStrategy", "ScorerOverridePolicy", - "register_scenario_techniques", "get_default_scorer_target", "get_default_adversarial_target", ] diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index a080c1550a..2ec0dc3ff8 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -2,11 +2,18 @@ # Licensed under the MIT license. """ -AttackTechniqueFactory — Deferred construction of AttackTechnique instances. - -Captures technique-specific configuration at registration time and produces -fresh, fully-constructed attacks when scenario-specific params (objective target, -scorer) become available. +AttackTechniqueFactory — Self-describing deferred constructor for AttackTechnique instances. + +Captures technique-specific configuration (name, strategy tags, attack class, +attack-class kwargs, optional adversarial chat, optional seed technique) at +construction time. Scenarios produce fresh, fully-constructed attacks by calling +``create()`` with scenario-specific params (objective target, scorer). + +The canonical place to register factories is the +``ScenarioTechniqueInitializer`` in +``pyrit.setup.initializers.components.scenario_techniques``. New initializers +register additional factories by calling +``AttackTechniqueRegistry.register_from_factories(...)``. """ from __future__ import annotations @@ -16,9 +23,13 @@ import sys import typing from enum import Enum +from pathlib import Path from typing import TYPE_CHECKING, Any, Union +from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier, Identifiable, build_seed_identifier +from pyrit.models import SeedAttackTechniqueGroup, SeedSimulatedConversation +from pyrit.models.seeds.seed_simulated_conversation import NextMessageSystemPromptPaths from pyrit.scenario.core.attack_technique import AttackTechnique if TYPE_CHECKING: @@ -28,7 +39,6 @@ AttackConverterConfig, AttackScoringConfig, ) - from pyrit.models import SeedAttackTechniqueGroup from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -44,12 +54,12 @@ class ScorerOverridePolicy(str, Enum): class AttackTechniqueFactory(Identifiable): """ - A factory that produces AttackTechnique instances on demand. + A self-describing factory that produces AttackTechnique instances on demand. - Captures technique-specific configuration (converters, adversarial config, - tree depth, etc.) at registration time. Produces fresh, fully-constructed - attacks by calling the real constructor with the captured params plus - scenario-specific objective_target and scoring config. + Captures technique-specific configuration (name, strategy tags, converters, + adversarial config, tree depth, etc.) at construction time. Produces fresh, + fully-constructed attacks by calling the real constructor with the captured + params plus scenario-specific objective_target and scoring config. Validates kwargs against the attack class constructor signature at construction time, catching typos and incompatible parameter names early. @@ -58,42 +68,191 @@ class AttackTechniqueFactory(Identifiable): def __init__( self, *, + name: str, attack_class: type[AttackStrategy[Any, Any]], + strategy_tags: list[str] | None = None, attack_kwargs: dict[str, Any] | None = None, + adversarial_chat: PromptTarget | None = None, adversarial_config: AttackAdversarialConfig | None = None, seed_technique: SeedAttackTechniqueGroup | None = None, + uses_adversarial: bool | None = None, scorer_override_policy: ScorerOverridePolicy = ScorerOverridePolicy.WARN, ) -> None: """ Initialize the factory with a technique-specific configuration. Args: + name: Registry name for this technique. Must match the strategy + enum value used by scenarios. attack_class: The AttackStrategy subclass to instantiate. + strategy_tags: Tags controlling which ``ScenarioStrategy`` + aggregates include this technique (e.g. ``"single_turn"``, + ``"multi_turn"``, ``"default"``). attack_kwargs: Keyword arguments to pass to the attack constructor. Must not include ``objective_target`` (provided at create time) - or ``attack_adversarial_config`` (use ``adversarial_config`` instead). - adversarial_config: Optional adversarial chat configuration. Stored - separately and injected into the attack at ``create()`` time if - the attack class accepts ``attack_adversarial_config``. Also - exposed via the ``adversarial_chat`` property for seed-technique - execution. - seed_technique: Optional technique seed group to attach to created techniques. - scorer_override_policy: What to do when a scenario's scorer is incompatible - with the attack's ``attack_scoring_config`` type annotation. Defaults to WARN. + or ``attack_adversarial_config`` (use ``adversarial_chat`` or + ``adversarial_config`` instead). + adversarial_chat: Convenience kwarg — adversarial chat target that + is wrapped into an ``AttackAdversarialConfig`` internally. + Mutually exclusive with ``adversarial_config``. + adversarial_config: Pre-built adversarial config. Mutually + exclusive with ``adversarial_chat``. Injected into the attack + at ``create()`` time if the attack class accepts + ``attack_adversarial_config``. + seed_technique: Optional technique seed group attached to created + techniques. + uses_adversarial: Whether this technique drives an adversarial + chat during execution. ``None`` auto-derives from the attack + class constructor signature and seed-technique shape. + Authors can override the derivation explicitly. + scorer_override_policy: What to do when a scenario's scorer is + incompatible with the attack's ``attack_scoring_config`` type + annotation. Defaults to WARN. Raises: TypeError: If any kwarg name is not a valid constructor parameter, or if the attack class constructor uses ``**kwargs``. - ValueError: If ``objective_target`` or ``attack_adversarial_config`` - is included in attack_kwargs. + ValueError: If both ``adversarial_chat`` and ``adversarial_config`` + are provided, if ``objective_target`` or + ``attack_adversarial_config`` is included in ``attack_kwargs``, + or if ``uses_adversarial=False`` while an adversarial chat is + wired. """ + if adversarial_chat is not None and adversarial_config is not None: + raise ValueError( + f"Factory '{name}': adversarial_chat and adversarial_config are mutually exclusive." + ) + + self._name = name self._attack_class = attack_class + self._strategy_tags = list(strategy_tags) if strategy_tags else [] self._attack_kwargs = dict(attack_kwargs) if attack_kwargs else {} - self._adversarial_config = adversarial_config + self._adversarial_config = adversarial_config or self._build_adversarial_config(adversarial_chat) + self._adversarial_config_was_explicit = self._adversarial_config is not None self._seed_technique = seed_technique self._scorer_override_policy = scorer_override_policy + self._uses_adversarial = ( + uses_adversarial if uses_adversarial is not None else self._derive_uses_adversarial() + ) + self._validate_kwargs() + self._validate_adversarial_flags() + + @classmethod + def with_simulated_conversation( + cls, + *, + name: str, + attack_class: type[AttackStrategy[Any, Any]] | None = None, + adversarial_chat_system_prompt_path: str | Path | None = None, + next_message_system_prompt_path: str | Path | None = None, + num_turns: int = 3, + strategy_tags: list[str] | None = None, + attack_kwargs: dict[str, Any] | None = None, + adversarial_chat: PromptTarget | None = None, + adversarial_config: AttackAdversarialConfig | None = None, + uses_adversarial: bool | None = None, + scorer_override_policy: ScorerOverridePolicy = ScorerOverridePolicy.WARN, + ) -> AttackTechniqueFactory: + """ + Alternative constructor that builds a ``SeedSimulatedConversation`` inline. + + Wraps a single ``SeedSimulatedConversation`` in a ``SeedAttackTechniqueGroup`` + and assigns it as ``seed_technique`` so callers don't have to construct + both manually. All other parameters are forwarded to ``__init__``. + + Args: + name: Registry name for this technique. When other defaults are used, + ``name`` also picks the canonical YAML at + ``EXECUTOR_SEED_PROMPT_PATH/red_teaming/{name}.yaml``. + attack_class: The AttackStrategy subclass to instantiate. Defaults to + ``PromptSendingAttack``. + adversarial_chat_system_prompt_path: Path to the YAML file containing + the adversarial chat system prompt for the simulated conversation. + Defaults to ``EXECUTOR_SEED_PROMPT_PATH/red_teaming/{name}.yaml``. + next_message_system_prompt_path: Optional path to the YAML file + containing the system prompt for generating a final user message + after the simulated conversation. Defaults to + ``NextMessageSystemPromptPaths.DIRECT.value``. + num_turns: Number of simulated conversation turns. Defaults to 3. + strategy_tags: Forwarded to ``__init__``. + attack_kwargs: Forwarded to ``__init__``. + adversarial_chat: Forwarded to ``__init__``. + adversarial_config: Forwarded to ``__init__``. + uses_adversarial: Forwarded to ``__init__``. + scorer_override_policy: Forwarded to ``__init__``. + + Returns: + AttackTechniqueFactory: A new factory whose ``seed_technique`` is the + wrapped simulated conversation. + """ + if attack_class is None: + from pyrit.executor.attack import PromptSendingAttack + + attack_class = PromptSendingAttack + if adversarial_chat_system_prompt_path is None: + adversarial_chat_system_prompt_path = Path(EXECUTOR_SEED_PROMPT_PATH) / "red_teaming" / f"{name}.yaml" + if next_message_system_prompt_path is None: + next_message_system_prompt_path = NextMessageSystemPromptPaths.DIRECT.value + + seed_technique = SeedAttackTechniqueGroup( + seeds=[ + SeedSimulatedConversation( + adversarial_chat_system_prompt_path=adversarial_chat_system_prompt_path, + next_message_system_prompt_path=next_message_system_prompt_path, + num_turns=num_turns, + ), + ], + ) + return cls( + name=name, + attack_class=attack_class, + strategy_tags=strategy_tags, + attack_kwargs=attack_kwargs, + adversarial_chat=adversarial_chat, + adversarial_config=adversarial_config, + seed_technique=seed_technique, + uses_adversarial=uses_adversarial, + scorer_override_policy=scorer_override_policy, + ) + + @staticmethod + def _build_adversarial_config( + adversarial_chat: PromptTarget | None, + ) -> AttackAdversarialConfig | None: + """Wrap a bare ``PromptTarget`` into an ``AttackAdversarialConfig``.""" + if adversarial_chat is None: + return None + from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig + + return AttackAdversarialConfig(target=adversarial_chat) + + def _derive_uses_adversarial(self) -> bool: + """ + Auto-derive ``uses_adversarial`` from the attack class signature and seed shape. + + Returns: + bool: ``True`` if the attack class accepts ``attack_adversarial_config`` + or the seed technique has a simulated conversation. + """ + sig = inspect.signature(self._attack_class.__init__) + if "attack_adversarial_config" in sig.parameters: + return True + return self._seed_technique is not None and self._seed_technique.has_simulated_conversation + + def _validate_adversarial_flags(self) -> None: + """ + Validate that ``uses_adversarial`` and ``adversarial_chat`` are coherent. + + Raises: + ValueError: If an adversarial chat is wired but ``uses_adversarial=False``. + """ + if not self._uses_adversarial and self._adversarial_config is not None: + raise ValueError( + f"Factory '{self._name}': adversarial_chat is set but uses_adversarial=False. " + f"A technique that doesn't use an adversarial chat must not have one wired." + ) def _validate_kwargs(self) -> None: """ @@ -113,7 +272,8 @@ def _validate_kwargs(self) -> None: raise ValueError("objective_target must not be in attack_kwargs — it is provided at create() time.") if "attack_adversarial_config" in self._attack_kwargs: raise ValueError( - "attack_adversarial_config must not be in attack_kwargs — use the adversarial_config parameter instead." + "attack_adversarial_config must not be in attack_kwargs — use adversarial_chat or " + "adversarial_config instead." ) sig = inspect.signature(self._attack_class.__init__) @@ -145,6 +305,21 @@ def _validate_kwargs(self) -> None: f"Valid parameters: {sorted(valid_params)}" ) + @property + def name(self) -> str: + """The registry name for this technique.""" + return self._name + + @property + def strategy_tags(self) -> list[str]: + """Tags controlling which ``ScenarioStrategy`` aggregates include this technique.""" + return list(self._strategy_tags) + + @property + def tags(self) -> list[str]: + """Alias for ``strategy_tags`` exposing the Taggable interface (used by ``TagQuery.filter``).""" + return list(self._strategy_tags) + @property def attack_class(self) -> type[AttackStrategy[Any, Any]]: """The attack strategy class this factory produces.""" @@ -160,6 +335,11 @@ def adversarial_chat(self) -> PromptTarget | None: """The adversarial chat target baked into this factory, or None.""" return self._adversarial_config.target if self._adversarial_config else None + @property + def uses_adversarial(self) -> bool: + """Whether this technique drives an adversarial chat during execution.""" + return self._uses_adversarial + def create( self, *, @@ -201,9 +381,24 @@ def create( A fresh AttackTechnique with a newly-constructed attack strategy. Raises: - ValueError: If ``scorer_override_policy`` is RAISE and the override - config is incompatible with the attack's type annotation. + ValueError: If ``attack_adversarial_config_override`` is supplied but + the factory already has an adversarial config baked in at + construction time, or if ``scorer_override_policy`` is RAISE and + the override config is incompatible with the attack's type annotation. """ + if attack_adversarial_config_override is not None and self._adversarial_config_was_explicit: + raise ValueError( + f"Factory '{self._name}': adversarial config was baked in at construction; " + f"cannot supply attack_adversarial_config_override." + ) + + if ( + self._uses_adversarial + and self._adversarial_config is None + and attack_adversarial_config_override is None + ): + self._adversarial_config = self._resolve_default_adversarial_config() + kwargs = dict(self._attack_kwargs) kwargs["objective_target"] = objective_target @@ -224,6 +419,14 @@ def create( attack = self._attack_class(**kwargs) return AttackTechnique(attack=attack, seed_technique=self._seed_technique) + @staticmethod + def _resolve_default_adversarial_config() -> AttackAdversarialConfig: + """Lazily resolve the default adversarial chat target and wrap it in a config.""" + from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig + from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target + + return AttackAdversarialConfig(target=get_default_adversarial_target()) + def _get_accepted_params(self) -> set[str]: """Return the set of keyword parameter names accepted by the attack class constructor.""" sig = inspect.signature(self._attack_class.__init__) @@ -381,19 +584,23 @@ def _build_identifier(self) -> ComponentIdentifier: """ Build the behavioral identity for this factory. - Includes the attack class name and kwargs with their serialized values - so that factories with different configurations produce different hashes. - When a seed technique is present, its seeds are added as - ``children["technique_seeds"]``. + Includes the factory name, attack class, kwargs, adversarial config, + and the adversarial-flag booleans so factories with different + configurations produce different hashes. When a seed technique is + present, its seeds are added as ``children["technique_seeds"]``. Returns: ComponentIdentifier: The frozen identity snapshot. """ kwargs_for_id = {k: self._serialize_value(v) for k, v in sorted(self._attack_kwargs.items())} params: dict[str, Any] = { + "name": self._name, "attack_class": self._attack_class.__name__, "kwargs": kwargs_for_id, + "uses_adversarial": self._uses_adversarial, } + if self._strategy_tags: + params["strategy_tags"] = list(self._strategy_tags) if self._adversarial_config is not None: params["adversarial_config"] = self._serialize_value(self._adversarial_config) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index aa48c3daa1..19caa57fda 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -322,21 +322,30 @@ def _get_attack_technique_factories(self) -> dict[str, "AttackTechniqueFactory"] value is an ``AttackTechniqueFactory`` that can produce an ``AttackTechnique`` for that technique. - The base implementation lazily populates the - ``AttackTechniqueRegistry`` singleton with core techniques (via - ``ScenarioTechniqueRegistrar``) and returns all registered factories. + The base implementation returns every factory currently registered in + the ``AttackTechniqueRegistry`` singleton. The canonical scenario + techniques are populated by ``ScenarioTechniqueInitializer`` + (``pyrit.setup.initializers.components.scenario_techniques``); ensure + that initializer has run before scenarios use this method. Subclasses may override to add, remove, or replace factories. Returns: dict[str, AttackTechniqueFactory]: Mapping of technique name to factory. - """ - from pyrit.scenario.core.scenario_techniques import register_scenario_techniques - - register_scenario_techniques() + Raises: + RuntimeError: If the registry is empty (no initializer has run). + """ from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry - return AttackTechniqueRegistry.get_registry_singleton().get_factories() + registry = AttackTechniqueRegistry.get_registry_singleton() + factories = registry.get_factories() + if not factories: + raise RuntimeError( + "AttackTechniqueRegistry is empty. Run ScenarioTechniqueInitializer " + "(pyrit.setup.initializers.components.scenario_techniques) before " + "executing scenarios." + ) + return factories def _build_display_group(self, *, technique_name: str, seed_group_name: str) -> str: """ diff --git a/pyrit/scenario/core/scenario_techniques.py b/pyrit/scenario/core/scenario_techniques.py deleted file mode 100644 index 05a8e40c31..0000000000 --- a/pyrit/scenario/core/scenario_techniques.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Scenario attack technique definitions and registration. - -Provides ``SCENARIO_TECHNIQUES`` (the static catalog used for strategy enum -construction) and ``register_scenario_techniques`` (registers specs with -resolved live targets into the ``AttackTechniqueRegistry`` singleton). - -To add a new technique, append an ``AttackTechniqueSpec`` to -``SCENARIO_TECHNIQUES``. If the technique requires an adversarial chat -target, it will be automatically resolved in ``build_scenario_techniques`` -by inspecting the attack class constructor signature. To use a specific -adversarial chat target from ``TargetRegistry``, set -``adversarial_chat_key`` on the spec. -""" - -from __future__ import annotations - -import dataclasses -import inspect -import logging -from pathlib import Path -from typing import TYPE_CHECKING - -from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH -from pyrit.executor.attack import ( - ContextComplianceAttack, - ManyShotJailbreakAttack, - PromptSendingAttack, - RedTeamingAttack, - RolePlayAttack, - RolePlayPaths, - TreeOfAttacksWithPruningAttack, -) -from pyrit.models import SeedAttackTechniqueGroup, SeedSimulatedConversation -from pyrit.models.seeds.seed_simulated_conversation import NextMessageSystemPromptPaths -from pyrit.registry import TargetRegistry -from pyrit.registry.object_registries.attack_technique_registry import ( - AttackTechniqueRegistry, - AttackTechniqueSpec, -) -from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target - -if TYPE_CHECKING: - from pyrit.prompt_target import PromptTarget - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Static technique catalog -# --------------------------------------------------------------------------- -# Used for strategy enum construction (import-time safe — no live targets). -# Live dependencies (e.g. adversarial chat targets) are resolved later by -# build_scenario_techniques() at registration time. - -SCENARIO_TECHNIQUES: list[AttackTechniqueSpec] = [ - AttackTechniqueSpec( - name="prompt_sending", - attack_class=PromptSendingAttack, - strategy_tags=["core", "single_turn", "default", "light"], - ), - AttackTechniqueSpec( - name="role_play", - attack_class=RolePlayAttack, - strategy_tags=["core", "single_turn", "light"], - extra_kwargs={"role_play_definition_path": RolePlayPaths.MOVIE_SCRIPT.value}, - ), - AttackTechniqueSpec( - name="many_shot", - attack_class=ManyShotJailbreakAttack, - strategy_tags=["core", "multi_turn", "default", "light"], - ), - AttackTechniqueSpec( - name="tap", - attack_class=TreeOfAttacksWithPruningAttack, - strategy_tags=["core", "multi_turn"], - ), - AttackTechniqueSpec( - name="crescendo_simulated", - attack_class=PromptSendingAttack, - strategy_tags=["core", "single_turn"], - seed_technique=SeedAttackTechniqueGroup( - seeds=[ - SeedSimulatedConversation( - adversarial_chat_system_prompt_path=( - Path(EXECUTOR_SEED_PROMPT_PATH) / "red_teaming" / "crescendo_simulated.yaml" - ), - next_message_system_prompt_path=NextMessageSystemPromptPaths.DIRECT.value, - num_turns=3, - ), - ], - ), - ), - AttackTechniqueSpec( - name="red_teaming", - attack_class=RedTeamingAttack, - strategy_tags=["core", "multi_turn", "light"], - ), - AttackTechniqueSpec( - name="context_compliance", - attack_class=ContextComplianceAttack, - strategy_tags=["core", "single_turn", "light"], - ), -] - - -# --------------------------------------------------------------------------- -# Runtime spec builder -# --------------------------------------------------------------------------- - - -def build_scenario_techniques() -> list[AttackTechniqueSpec]: - """ - Return a copy of ``SCENARIO_TECHNIQUES`` with ``adversarial_chat`` baked - into each spec that requires one. - - This is a mechanical transform of the static catalog. - - Resolution order for each spec: - - 1. If ``adversarial_chat_key`` is set, look it up in ``TargetRegistry``. - Raises ``ValueError`` if the key is not found. - 2. Otherwise, if the attack class accepts ``attack_adversarial_config`` - or the spec's ``seed_technique`` has a simulated conversation, - fill in the default from ``get_default_adversarial_target()``. - 3. Otherwise, pass through unchanged. - - Returns: - list[AttackTechniqueSpec]: Specs ready for registration. - - Raises: - ValueError: If a spec declares ``adversarial_chat_key`` but the key - is not found in ``TargetRegistry``. - """ - default_adversarial: PromptTarget | None = None - - result = [] - for spec in SCENARIO_TECHNIQUES: - if spec.adversarial_chat_key: - registry = TargetRegistry.get_registry_singleton() - resolved = registry.get(spec.adversarial_chat_key) - if resolved is None: - raise ValueError( - f"Technique spec '{spec.name}' references adversarial_chat_key " - f"'{spec.adversarial_chat_key}', but no such entry exists in TargetRegistry." - ) - result.append( - dataclasses.replace( - spec, - adversarial_chat=resolved, - adversarial_chat_key=None, - ) - ) - elif _spec_needs_adversarial(spec): - if default_adversarial is None: - default_adversarial = get_default_adversarial_target() - result.append(dataclasses.replace(spec, adversarial_chat=default_adversarial)) - else: - result.append(spec) - return result - - -def _spec_needs_adversarial(spec: AttackTechniqueSpec) -> bool: - """ - Check if a spec requires an adversarial chat target. - - Returns: - True if the attack class accepts ``attack_adversarial_config`` - or the spec's seed technique has a simulated conversation. - """ - if "attack_adversarial_config" in inspect.signature(spec.attack_class.__init__).parameters: # type: ignore[misc] - return True - return spec.seed_technique is not None and spec.seed_technique.has_simulated_conversation - - -# --------------------------------------------------------------------------- -# Registration helper -# --------------------------------------------------------------------------- - - -def register_scenario_techniques() -> None: - """ - Register all ``SCENARIO_TECHNIQUES`` into the ``AttackTechniqueRegistry`` singleton. - - Per-name idempotent: existing entries are not overwritten. - - Resolves the default adversarial target, bakes it into the specs that - require it, then registers the resulting factories. - """ - specs = build_scenario_techniques() - - registry = AttackTechniqueRegistry.get_registry_singleton() - registry.register_from_specs(specs) diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index d29b81eecc..3d11ae452d 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -20,28 +20,31 @@ logger = logging.getLogger(__name__) -_CYBER_TECHNIQUE_NAMES = {"prompt_sending", "red_teaming"} +_CYBER_TECHNIQUE_NAMES = {"red_teaming"} def _build_cyber_strategy() -> type[ScenarioStrategy]: """ - Build the Cyber strategy class dynamically from SCENARIO_TECHNIQUES. + Build the Cyber strategy class dynamically from the registered technique factories. - Selects only ``prompt_sending`` and ``red_teaming`` techniques from - the shared catalog. + Selects only the ``red_teaming`` factory from the singleton + ``AttackTechniqueRegistry``. A plain ``PromptSendingAttack`` baseline is + prepended automatically by ``Scenario._build_baseline_atomic_attack`` via + ``BaselineAttackPolicy.Enabled``. Returns: type[ScenarioStrategy]: The dynamically generated strategy enum class. """ from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.registry.tag_query import TagQuery - from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES - cyber_specs = [s for s in SCENARIO_TECHNIQUES if s.name in _CYBER_TECHNIQUE_NAMES] + registry = AttackTechniqueRegistry.get_registry_singleton() + factories = registry.get_factories() + cyber_factories = [f for name, f in factories.items() if name in _CYBER_TECHNIQUE_NAMES] - return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[ty:invalid-return-type] + return AttackTechniqueRegistry.build_strategy_class_from_factories( # type: ignore[ty:invalid-return-type] class_name="CyberStrategy", - specs=cyber_specs, + factories=cyber_factories, aggregate_tags={ "single_turn": TagQuery.any_of("single_turn"), "multi_turn": TagQuery.any_of("multi_turn"), diff --git a/pyrit/scenario/scenarios/airt/leakage.py b/pyrit/scenario/scenarios/airt/leakage.py index 2ccf54f768..16712f14f9 100644 --- a/pyrit/scenario/scenarios/airt/leakage.py +++ b/pyrit/scenario/scenarios/airt/leakage.py @@ -16,9 +16,9 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration from pyrit.registry.object_registries.attack_technique_registry import ( AttackTechniqueRegistry, - AttackTechniqueSpec, ) from pyrit.registry.tag_query import TagQuery +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioStrategy @@ -26,7 +26,6 @@ if TYPE_CHECKING: from pathlib import Path - from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.score import TrueFalseScorer @@ -38,22 +37,22 @@ _BLANK_IMAGE_PATH = str(DATASETS_PATH / "seed_datasets" / "local" / "examples" / "blank_canvas.png") -LEAKAGE_TECHNIQUES: list[AttackTechniqueSpec] = [ - AttackTechniqueSpec( +LEAKAGE_FACTORIES: list[AttackTechniqueFactory] = [ + AttackTechniqueFactory( name="first_letter", attack_class=PromptSendingAttack, strategy_tags=["single_turn", "default"], - extra_kwargs={ + attack_kwargs={ "attack_converter_config": AttackConverterConfig( request_converters=PromptConverterConfiguration.from_converters(converters=[FirstLetterConverter()]) ), }, ), - AttackTechniqueSpec( + AttackTechniqueFactory( name="image", attack_class=PromptSendingAttack, strategy_tags=["single_turn", "default"], - extra_kwargs={ + attack_kwargs={ "attack_converter_config": AttackConverterConfig( request_converters=PromptConverterConfiguration.from_converters( converters=[AddImageTextConverter(img_to_add=_BLANK_IMAGE_PATH)] @@ -66,20 +65,20 @@ def _build_leakage_strategy() -> type[ScenarioStrategy]: """ - Build the Leakage strategy class dynamically from core + leakage-specific techniques. + Build the Leakage strategy class dynamically from core + leakage-specific factories. - Combines core SCENARIO_TECHNIQUES with leakage-unique techniques (first_letter, image) - to provide a full set of attack strategies. + Combines core factories (from the registry) with leakage-unique factories + (``first_letter``, ``image``) to provide the full set of attack strategies. Returns: type[ScenarioStrategy]: The dynamically generated strategy enum class. """ - from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES - - all_specs = SCENARIO_TECHNIQUES + LEAKAGE_TECHNIQUES - return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[return-value, ty:invalid-return-type] + registry = AttackTechniqueRegistry.get_registry_singleton() + core_factories = list(registry.get_factories().values()) + all_factories = core_factories + LEAKAGE_FACTORIES + return AttackTechniqueRegistry.build_strategy_class_from_factories( # type: ignore[return-value, ty:invalid-return-type] class_name="LeakageStrategy", - specs=all_specs, + factories=all_factories, aggregate_tags={ "default": TagQuery.any_of("default"), "single_turn": TagQuery.any_of("single_turn"), @@ -167,15 +166,16 @@ def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: """ Return core + leakage-specific attack technique factories. - Gets core factories from the base class, then builds leakage-specific - factories locally without registering them in the global registry. + Gets core factories from the base class, then merges in the + leakage-specific factories (kept local to this scenario so they don't + pollute the global registry). Returns: dict[str, AttackTechniqueFactory]: Mapping of technique names to their factories. """ factories = super()._get_attack_technique_factories() - for spec in LEAKAGE_TECHNIQUES: - factories[spec.name] = AttackTechniqueRegistry.build_factory_from_spec(spec) + for factory in LEAKAGE_FACTORIES: + factories[factory.name] = factory return factories diff --git a/pyrit/scenario/scenarios/airt/rapid_response.py b/pyrit/scenario/scenarios/airt/rapid_response.py index 41d853f214..b98ef052bc 100644 --- a/pyrit/scenario/scenarios/airt/rapid_response.py +++ b/pyrit/scenario/scenarios/airt/rapid_response.py @@ -28,20 +28,23 @@ def _build_rapid_response_strategy() -> type[ScenarioStrategy]: """ - Build the RapidResponse strategy class dynamically from SCENARIO_TECHNIQUES. + Build the RapidResponse strategy class dynamically from the registered factories. - Reads the spec list (pure data) — no registry interaction or target resolution. + Reads the singleton ``AttackTechniqueRegistry`` and filters to factories + tagged ``core``. Returns: type[ScenarioStrategy]: The dynamically generated strategy enum class. """ from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.registry.tag_query import TagQuery - from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES - return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[ty:invalid-return-type] + registry = AttackTechniqueRegistry.get_registry_singleton() + factories = list(registry.get_factories().values()) + + return AttackTechniqueRegistry.build_strategy_class_from_factories( # type: ignore[ty:invalid-return-type] class_name="RapidResponseStrategy", - specs=TagQuery.all("core").filter(SCENARIO_TECHNIQUES), + factories=TagQuery.all("core").filter(factories), aggregate_tags={ "default": TagQuery.any_of("default"), "single_turn": TagQuery.any_of("single_turn"), diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index dfec12839c..e6a357d8b3 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -11,15 +11,15 @@ from pyrit.common import apply_defaults from pyrit.executor.attack import AttackAdversarialConfig, AttackScoringConfig from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS -from pyrit.registry import AttackTechniqueRegistry, AttackTechniqueSpec +from pyrit.registry import AttackTechniqueRegistry from pyrit.registry.tag_query import TagQuery from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario -from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES if TYPE_CHECKING: from pyrit.prompt_target import PromptTarget + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.score import TrueFalseScorer @@ -145,9 +145,9 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Build atomic attacks from the cross-product of techniques × models × datasets. - Factories are built locally from adversarial-capable ``SCENARIO_TECHNIQUES`` - (not the registry singleton). Each model is injected at create-time via - ``attack_adversarial_config_override``. + Factories are read from the singleton ``AttackTechniqueRegistry`` and + narrowed to adversarial-capable ones. Each model is injected at + create-time via ``attack_adversarial_config_override``. Returns: list[AtomicAttack]: One atomic attack per technique/model/dataset combination. @@ -160,10 +160,8 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: "Scenario not properly initialized. Call await scenario.initialize_async() before running." ) - benchmarkable_specs = AdversarialBenchmark._get_benchmarkable_specs() - local_factories = { - spec.name: AttackTechniqueRegistry.build_factory_from_spec(spec) for spec in benchmarkable_specs - } + benchmarkable_factories = AdversarialBenchmark._get_benchmarkable_factories() + local_factories = {factory.name: factory for factory in benchmarkable_factories} selected_techniques = {s.value for s in self._scenario_strategies} seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() @@ -255,19 +253,17 @@ def _infer_labels( @staticmethod def _build_benchmark_strategy() -> type[ScenarioStrategy]: """ - Build the BenchmarkStrategy enum from adversarial-capable ``SCENARIO_TECHNIQUES``. + Build the BenchmarkStrategy enum from adversarial-capable factories. Returns a strategy class whose concrete members are adversarial-capable - techniques (no baked-in adversarial chat) and whose aggregates allow - selecting by turn style. + techniques and whose aggregates allow selecting by turn style. Returns: type[ScenarioStrategy]: The dynamically generated strategy enum class. """ - specs = AdversarialBenchmark._get_benchmarkable_specs() - return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[ty:invalid-return-type] + return AttackTechniqueRegistry.build_strategy_class_from_factories( # type: ignore[ty:invalid-return-type] class_name="BenchmarkStrategy", - specs=TagQuery.all("core").filter(specs), + factories=AdversarialBenchmark._get_benchmarkable_factories(), aggregate_tags={ "default": TagQuery.any_of("default"), "single_turn": TagQuery.any_of("single_turn"), @@ -277,20 +273,21 @@ def _build_benchmark_strategy() -> type[ScenarioStrategy]: ) @staticmethod - def _get_benchmarkable_specs() -> list[AttackTechniqueSpec]: + def _get_benchmarkable_factories() -> list[AttackTechniqueFactory]: """ - Return techniques from ``SCENARIO_TECHNIQUES`` that accept an adversarial - model but don't have one already baked in. + Return ``core`` factories that drive an adversarial chat. - This is the dual guard: ``_accepts_adversarial`` ensures the technique - CAN use an adversarial model, and ``adversarial_chat is None`` ensures - it doesn't already have one set — we inject our own at create-time. + Every benchmark technique must accept an adversarial-config override at + ``create()`` time so the scenario can inject one chat per benchmark + model. We narrow to the ``core`` tag to exclude experimental / persona + variants. Returns: - list[AttackTechniqueSpec]: Filtered, adversarial-ready specs. + list[AttackTechniqueFactory]: Filtered core, adversarial-capable factories. """ + registry = AttackTechniqueRegistry.get_registry_singleton() return [ - spec - for spec in SCENARIO_TECHNIQUES - if AttackTechniqueRegistry._accepts_adversarial(spec.attack_class) and spec.adversarial_chat is None + factory + for factory in registry.get_factories().values() + if factory.uses_adversarial and "core" in factory.strategy_tags ] diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index 84aeb83a49..da9e65dc5d 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -6,7 +6,7 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.common.parameter import Parameter from pyrit.setup.initializers.airt import AIRTInitializer -from pyrit.setup.initializers.components.scenarios import ScenarioTechniqueInitializer +from pyrit.setup.initializers.components.scenario_techniques import ScenarioTechniqueInitializer from pyrit.setup.initializers.components.scorers import ScorerInitializer from pyrit.setup.initializers.components.targets import TargetInitializer from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer diff --git a/pyrit/setup/initializers/components/__init__.py b/pyrit/setup/initializers/components/__init__.py index 7da38c774c..ba2dd6f32b 100644 --- a/pyrit/setup/initializers/components/__init__.py +++ b/pyrit/setup/initializers/components/__init__.py @@ -3,7 +3,7 @@ """Component initializers for targets, scorers, and other components.""" -from pyrit.setup.initializers.components.scenarios import ScenarioTechniqueInitializer +from pyrit.setup.initializers.components.scenario_techniques import ScenarioTechniqueInitializer from pyrit.setup.initializers.components.scorers import ScorerInitializer, ScorerInitializerTags from pyrit.setup.initializers.components.targets import TargetConfig, TargetInitializer, TargetInitializerTags diff --git a/pyrit/setup/initializers/components/scenario_techniques.py b/pyrit/setup/initializers/components/scenario_techniques.py new file mode 100644 index 0000000000..6f8b990f0b --- /dev/null +++ b/pyrit/setup/initializers/components/scenario_techniques.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario technique initializer. + +This module owns the canonical catalog of scenario attack techniques as a +flat list of self-describing :class:`AttackTechniqueFactory` instances and +registers them into the singleton :class:`AttackTechniqueRegistry` via +:class:`ScenarioTechniqueInitializer`. + +Per-name registration is idempotent: pre-existing entries in the registry are +not overwritten. +""" + +from __future__ import annotations + +import logging + +from pyrit.executor.attack import ( + ContextComplianceAttack, + ManyShotJailbreakAttack, + RedTeamingAttack, + RolePlayAttack, + RolePlayPaths, + TreeOfAttacksWithPruningAttack, +) +from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueRegistry, +) +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +logger = logging.getLogger(__name__) + + +def build_scenario_technique_factories() -> list[AttackTechniqueFactory]: + """Build the canonical scenario technique factories. + + Factories that need an adversarial chat target do not bake one in; the + default adversarial target is resolved lazily inside + :meth:`AttackTechniqueFactory.create` via + ``get_default_adversarial_target()``. Scenarios may also pass + ``attack_adversarial_config_override`` at create time (but only when the + factory did not bake one in at construction). + + A bare ``PromptSendingAttack`` factory is intentionally omitted from the + catalog: every scenario whose ``BASELINE_ATTACK_POLICY`` is + :attr:`BaselineAttackPolicy.Enabled` already auto-prepends an equivalent + baseline atomic attack via ``Scenario._build_baseline_atomic_attack``. + + Returns: + list[AttackTechniqueFactory]: The full catalog of scenario techniques. + """ + return [ + AttackTechniqueFactory( + name="role_play", + attack_class=RolePlayAttack, + strategy_tags=["core", "single_turn", "default", "light"], + attack_kwargs={"role_play_definition_path": RolePlayPaths.MOVIE_SCRIPT.value}, + ), + AttackTechniqueFactory( + name="many_shot", + attack_class=ManyShotJailbreakAttack, + strategy_tags=["core", "multi_turn", "default", "light"], + ), + AttackTechniqueFactory( + name="tap", + attack_class=TreeOfAttacksWithPruningAttack, + strategy_tags=["core", "multi_turn"], + ), + AttackTechniqueFactory.with_simulated_conversation( + name="crescendo_simulated", + strategy_tags=["core", "single_turn"], + ), + AttackTechniqueFactory( + name="red_teaming", + attack_class=RedTeamingAttack, + strategy_tags=["core", "multi_turn", "light"], + ), + AttackTechniqueFactory( + name="context_compliance", + attack_class=ContextComplianceAttack, + strategy_tags=["core", "single_turn", "light"], + ), + AttackTechniqueFactory.with_simulated_conversation( + name="crescendo_movie_director", + strategy_tags=["core", "single_turn"], + ), + AttackTechniqueFactory.with_simulated_conversation( + name="crescendo_history_lecture", + strategy_tags=["core", "single_turn"], + ), + AttackTechniqueFactory.with_simulated_conversation( + name="crescendo_journalist_interview", + strategy_tags=["core", "single_turn"], + ), + ] + + +class ScenarioTechniqueInitializer(PyRITInitializer): + """Register the canonical scenario attack technique factories. + + Builds and registers the 6 core techniques (``role_play``, ``many_shot``, + ``tap``, ``crescendo_simulated``, ``red_teaming``, ``context_compliance``) + together with the persona-driven crescendo variants + (``crescendo_movie_director``, ``crescendo_history_lecture``, + ``crescendo_journalist_interview``). + + A bare ``PromptSendingAttack`` factory is intentionally not registered: the + scenario-level baseline (``BaselineAttackPolicy.Enabled`` + + ``Scenario._build_baseline_atomic_attack``) already covers that case. + + Registration is per-name idempotent: pre-existing entries in + :class:`AttackTechniqueRegistry` are not overwritten. + """ + + async def initialize_async(self) -> None: + """Build the canonical factories and register them into the singleton registry.""" + factories = build_scenario_technique_factories() + + registry = AttackTechniqueRegistry.get_registry_singleton() + registry.register_from_factories(factories) + + registered_names = [f.name for f in factories if f.name in registry] + logger.info( + "Registered %d scenario technique factory(ies): %s", + len(registered_names), + ", ".join(registered_names), + ) diff --git a/pyrit/setup/initializers/components/scenarios.py b/pyrit/setup/initializers/components/scenarios.py deleted file mode 100644 index c12f81740c..0000000000 --- a/pyrit/setup/initializers/components/scenarios.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Scenario Technique Initializer for registering persona-driven crescendo techniques. - -This module provides the ScenarioTechniqueInitializer class that registers -additional ``AttackTechniqueSpec`` entries into the singleton -``AttackTechniqueRegistry``, on top of the core specs declared in -``pyrit.scenario.core.scenario_techniques.SCENARIO_TECHNIQUES``. - -The techniques registered here are persona-driven YAML variants of the canonical -``crescendo_simulated`` technique introduced in PR #1665. They reuse -``PromptSendingAttack`` plus a ``SeedSimulatedConversation`` whose adversarial -chat is driven by a persona-specific YAML system prompt. No new attack -primitives are introduced. - -Per-name registration is idempotent: existing entries in the registry are not -overwritten. -""" - -import dataclasses -import logging -from pathlib import Path - -from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH -from pyrit.executor.attack import PromptSendingAttack -from pyrit.models import SeedAttackTechniqueGroup, SeedSimulatedConversation -from pyrit.models.seeds.seed_simulated_conversation import NextMessageSystemPromptPaths -from pyrit.registry.object_registries.attack_technique_registry import ( - AttackTechniqueRegistry, - AttackTechniqueSpec, -) -from pyrit.scenario.core.scenario_techniques import ( - get_default_adversarial_target, - register_scenario_techniques, -) -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer - -logger = logging.getLogger(__name__) - - -# Names of the persona-driven crescendo techniques registered by this initializer. -# Each name corresponds to a YAML file under -# ``pyrit/datasets/executors/red_teaming/.yaml``. -CRESCENDO_MOVIE_DIRECTOR: str = "crescendo_movie_director" -CRESCENDO_HISTORY_LECTURE: str = "crescendo_history_lecture" -CRESCENDO_JOURNALIST_INTERVIEW: str = "crescendo_journalist_interview" - -PERSONA_CRESCENDO_TECHNIQUE_NAMES: list[str] = [ - CRESCENDO_MOVIE_DIRECTOR, - CRESCENDO_HISTORY_LECTURE, - CRESCENDO_JOURNALIST_INTERVIEW, -] - - -def _build_persona_crescendo_spec(*, name: str) -> AttackTechniqueSpec: - """ - Build a persona-driven crescendo ``AttackTechniqueSpec``. - - Mirrors the wiring of the canonical ``crescendo_simulated`` spec from - ``pyrit.scenario.core.scenario_techniques``: ``PromptSendingAttack`` plus a - ``SeedSimulatedConversation`` whose adversarial chat reads its system prompt - from ``pyrit/datasets/executors/red_teaming/.yaml``. ``num_turns`` - matches the canonical default of 3. - - Args: - name: The technique name. Must match the YAML filename stem under - ``pyrit/datasets/executors/red_teaming/``. - - Returns: - AttackTechniqueSpec: A spec ready for adversarial-chat resolution and - registration via ``AttackTechniqueRegistry.register_from_specs``. - """ - return AttackTechniqueSpec( - name=name, - attack_class=PromptSendingAttack, - strategy_tags=["core", "single_turn"], - seed_technique=SeedAttackTechniqueGroup( - seeds=[ - SeedSimulatedConversation( - adversarial_chat_system_prompt_path=( - Path(EXECUTOR_SEED_PROMPT_PATH) / "red_teaming" / f"{name}.yaml" - ), - next_message_system_prompt_path=NextMessageSystemPromptPaths.DIRECT.value, - num_turns=3, - ), - ], - ), - ) - - -def build_persona_crescendo_specs() -> list[AttackTechniqueSpec]: - """ - Build the full set of persona-driven crescendo specs registered by this initializer. - - Returns: - list[AttackTechniqueSpec]: One spec per persona variant, in registration order. - """ - return [_build_persona_crescendo_spec(name=name) for name in PERSONA_CRESCENDO_TECHNIQUE_NAMES] - - -class ScenarioTechniqueInitializer(PyRITInitializer): - """ - Register persona-driven crescendo scenario techniques into the registry. - - This initializer first ensures the core ``SCENARIO_TECHNIQUES`` are registered - (via ``register_scenario_techniques``), then appends the persona-driven - crescendo variants. Each variant is wired with the same default adversarial - chat target as ``crescendo_simulated``, since they share the - ``SeedSimulatedConversation`` shape. - - Registration is per-name idempotent: pre-existing entries in - ``AttackTechniqueRegistry`` are not overwritten. - """ - - async def initialize_async(self) -> None: - """ - Register the persona-driven crescendo specs into the singleton registry. - - First ensures the core ``SCENARIO_TECHNIQUES`` are registered, then - builds and registers each persona variant with the default adversarial - chat target baked in. Registration is per-name idempotent. - """ - register_scenario_techniques() - - default_adversarial = get_default_adversarial_target() - persona_specs = [ - dataclasses.replace(spec, adversarial_chat=default_adversarial) for spec in build_persona_crescendo_specs() - ] - - registry = AttackTechniqueRegistry.get_registry_singleton() - registry.register_from_specs(persona_specs) - - registered_names = [spec.name for spec in persona_specs if spec.name in registry] - logger.info( - "Registered %d persona-driven crescendo technique(s): %s", - len(registered_names), - ", ".join(registered_names), - ) diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 0471786b36..6aab39f0f9 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -39,7 +39,7 @@ def _make_scenario_metadata( class_name: str = "TestScenario", description: str = "A test scenario", default_strategy: str = "default", - all_strategies: tuple[str, ...] = ("prompt_sending", "role_play"), + all_strategies: tuple[str, ...] = ("role_play", "many_shot"), aggregate_strategies: tuple[str, ...] = ("all", "default"), default_datasets: tuple[str, ...] = ("test_dataset",), max_dataset_size: int | None = None, @@ -95,7 +95,7 @@ async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: assert result.items[0].description == "A test scenario" assert result.items[0].default_strategy == "default" assert result.items[0].aggregate_strategies == ["all", "default"] - assert result.items[0].all_strategies == ["prompt_sending", "role_play"] + assert result.items[0].all_strategies == ["role_play", "many_shot"] assert result.items[0].default_datasets == ["test_dataset"] assert result.items[0].max_dataset_size is None @@ -229,7 +229,7 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: description="Red team agent testing", default_strategy="default", aggregate_strategies=["all", "default"], - all_strategies=["prompt_sending", "role_play"], + all_strategies=["role_play", "many_shot"], default_datasets=["airt_hate"], max_dataset_size=10, ) @@ -254,7 +254,7 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: assert item["scenario_type"] == "RedTeamAgentScenario" assert item["default_strategy"] == "default" assert item["aggregate_strategies"] == ["all", "default"] - assert item["all_strategies"] == ["prompt_sending", "role_play"] + assert item["all_strategies"] == ["role_play", "many_shot"] assert item["default_datasets"] == ["airt_hate"] assert item["max_dataset_size"] == 10 @@ -283,7 +283,7 @@ def test_get_scenario_returns_200(self, client: TestClient) -> None: description="Red team agent testing", default_strategy="default", aggregate_strategies=["all"], - all_strategies=["prompt_sending"], + all_strategies=["role_play"], default_datasets=["airt_hate"], max_dataset_size=None, ) @@ -351,7 +351,7 @@ async def test_list_scenarios_includes_supported_parameters(self) -> None: class_module="pyrit.scenario.scenarios.param", class_description="A scenario with params", default_strategy="default", - all_strategies=("prompt_sending",), + all_strategies=("role_play",), aggregate_strategies=("all",), default_datasets=("test_dataset",), max_dataset_size=None, diff --git a/tests/unit/registry/test_attack_technique_registry.py b/tests/unit/registry/test_attack_technique_registry.py index 71e942ed7f..25d8c0def8 100644 --- a/tests/unit/registry/test_attack_technique_registry.py +++ b/tests/unit/registry/test_attack_technique_registry.py @@ -11,9 +11,10 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.identifiers import ComponentIdentifier from pyrit.prompt_target import PromptTarget -from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry, AttackTechniqueSpec +from pyrit.registry import TargetRegistry +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy -from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES +from pyrit.setup.initializers.components.scenario_techniques import build_scenario_technique_factories class _StubAttack: @@ -85,7 +86,7 @@ def teardown_method(self): AttackTechniqueRegistry.reset_instance() def test_register_technique_stores_factory(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="stub_attack", attack_class=_StubAttack) self.registry.register_technique(name="stub_attack", factory=factory) @@ -93,7 +94,7 @@ def test_register_technique_stores_factory(self): assert self.registry._registry_items["stub_attack"].instance is factory def test_register_technique_with_tags(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="stub_attack", attack_class=_StubAttack) self.registry.register_technique( name="stub_attack", @@ -106,8 +107,9 @@ def test_register_technique_with_tags(self): assert entries[0].name == "stub_attack" def test_register_multiple_techniques(self): - factory1 = AttackTechniqueFactory(attack_class=_StubAttack) + factory1 = AttackTechniqueFactory(name="stub_5", attack_class=_StubAttack) factory2 = AttackTechniqueFactory( + name="stub_20", attack_class=_StubAttack, attack_kwargs={"max_turns": 20}, ) @@ -130,7 +132,7 @@ def teardown_method(self): AttackTechniqueRegistry.reset_instance() def test_build_metadata_returns_component_identifier(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="stub", attack_class=_StubAttack) self.registry.register_technique(name="stub", factory=factory) metadata = self.registry.list_metadata() @@ -140,7 +142,7 @@ def test_build_metadata_returns_component_identifier(self): assert metadata[0].class_name == "AttackTechniqueFactory" def test_metadata_matches_factory_identifier(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="stub", attack_class=_StubAttack) self.registry.register_technique(name="stub", factory=factory) metadata = self.registry.list_metadata() @@ -159,7 +161,7 @@ def teardown_method(self): AttackTechniqueRegistry.reset_instance() def test_contains(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="exists", attack_class=_StubAttack) self.registry.register_technique(name="exists", factory=factory) assert "exists" in self.registry @@ -168,22 +170,26 @@ def test_contains(self): def test_len(self): assert len(self.registry) == 0 - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="a", attack_class=_StubAttack) self.registry.register_technique(name="a", factory=factory) assert len(self.registry) == 1 def test_get_names_returns_sorted(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) - self.registry.register_technique(name="zeta", factory=factory) - self.registry.register_technique(name="alpha", factory=factory) - self.registry.register_technique(name="beta", factory=factory) + factory_zeta = AttackTechniqueFactory(name="zeta", attack_class=_StubAttack) + factory_alpha = AttackTechniqueFactory(name="alpha", attack_class=_StubAttack) + factory_beta = AttackTechniqueFactory(name="beta", attack_class=_StubAttack) + self.registry.register_technique(name="zeta", factory=factory_zeta) + self.registry.register_technique(name="alpha", factory=factory_alpha) + self.registry.register_technique(name="beta", factory=factory_beta) assert self.registry.get_names() == ["alpha", "beta", "zeta"] def test_tag_based_queries(self): - factory1 = AttackTechniqueFactory(attack_class=_StubAttack) - factory2 = AttackTechniqueFactory(attack_class=_StubAttack, attack_kwargs={"max_turns": 20}) + factory1 = AttackTechniqueFactory(name="f1", attack_class=_StubAttack) + factory2 = AttackTechniqueFactory( + name="f2", attack_class=_StubAttack, attack_kwargs={"max_turns": 20} + ) self.registry.register_technique(name="f1", factory=factory1, tags=["multi_turn"]) self.registry.register_technique(name="f2", factory=factory2, tags=["single_turn"]) @@ -197,15 +203,18 @@ def test_tag_based_queries(self): assert single[0].name == "f2" def test_iter_yields_sorted_names(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) - self.registry.register_technique(name="b", factory=factory) - self.registry.register_technique(name="a", factory=factory) + factory_b = AttackTechniqueFactory(name="b", attack_class=_StubAttack) + factory_a = AttackTechniqueFactory(name="a", attack_class=_StubAttack) + self.registry.register_technique(name="b", factory=factory_b) + self.registry.register_technique(name="a", factory=factory_a) assert list(self.registry) == ["a", "b"] def test_get_factories_returns_dict_mapping(self): - factory_a = AttackTechniqueFactory(attack_class=_StubAttack) - factory_b = AttackTechniqueFactory(attack_class=_StubAttack, attack_kwargs={"max_turns": 5}) + factory_a = AttackTechniqueFactory(name="alpha", attack_class=_StubAttack) + factory_b = AttackTechniqueFactory( + name="beta", attack_class=_StubAttack, attack_kwargs={"max_turns": 5} + ) self.registry.register_technique(name="alpha", factory=factory_a) self.registry.register_technique(name="beta", factory=factory_b) @@ -240,42 +249,56 @@ def test_policy_is_read_only(self): with pytest.raises(AttributeError): self.registry.scorer_override_policy = ScorerOverridePolicy.RAISE - def test_policy_passed_to_factories_via_register_from_specs(self): - """Factories built via register_from_specs inherit the registry's default policy.""" - spec = AttackTechniqueSpec(name="stub_policy", attack_class=_StubAttack, strategy_tags=["test"]) - self.registry.register_from_specs([spec]) + def test_policy_passed_to_factories_via_register_from_factories(self): + """Factories registered via register_from_factories inherit the registry's default policy.""" + factory = AttackTechniqueFactory( + name="stub_policy", attack_class=_StubAttack, strategy_tags=["test"] + ) + self.registry.register_from_factories([factory]) + + stored = self.registry._registry_items["stub_policy"].instance + assert stored._scorer_override_policy == ScorerOverridePolicy.WARN - factory = self.registry._registry_items["stub_policy"].instance - assert factory._scorer_override_policy == ScorerOverridePolicy.WARN +SCENARIO_FACTORIES_FIXTURE: list[AttackTechniqueFactory] = [] -class TestScenarioTechniqueSpecsValid: - """Validate that every AttackTechniqueSpec in SCENARIO_TECHNIQUES is well-formed.""" - @pytest.mark.parametrize("spec", SCENARIO_TECHNIQUES, ids=lambda s: s.name) - def test_spec_extra_kwargs_match_attack_class_constructor(self, spec: AttackTechniqueSpec): - """Each spec's extra_kwargs must be valid parameters of its attack_class.""" - factory = AttackTechniqueRegistry.build_factory_from_spec(spec) - assert factory.attack_class is spec.attack_class +def _scenario_factories() -> list[AttackTechniqueFactory]: + """Build the scenario technique factories once for parametrization. - @pytest.mark.parametrize("spec", SCENARIO_TECHNIQUES, ids=lambda s: s.name) - def test_spec_attack_class_accepts_objective_target(self, spec: AttackTechniqueSpec): - """Every attack class must accept objective_target (required at create time).""" - sig = inspect.signature(spec.attack_class.__init__) + Uses a mock adversarial target in ``TargetRegistry`` so the build does + not depend on environment variables or OpenAIChatTarget. + """ + if not SCENARIO_FACTORIES_FIXTURE: + TargetRegistry.reset_instance() + adv_target = MagicMock(spec=PromptTarget) + adv_target.capabilities.includes.return_value = True + TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + SCENARIO_FACTORIES_FIXTURE.extend(build_scenario_technique_factories()) + return SCENARIO_FACTORIES_FIXTURE + + +class TestScenarioTechniqueFactoriesValid: + """Validate that every factory built by ``build_scenario_technique_factories`` is well-formed.""" + + @pytest.mark.parametrize("factory", _scenario_factories(), ids=lambda f: f.name) + def test_factory_attack_class_set(self, factory: AttackTechniqueFactory): + """Each factory references an attack class.""" + assert factory.attack_class is not None + + @pytest.mark.parametrize("factory", _scenario_factories(), ids=lambda f: f.name) + def test_factory_attack_class_accepts_objective_target(self, factory: AttackTechniqueFactory): + """Every attack class must accept ``objective_target`` (required at create time).""" + sig = inspect.signature(factory.attack_class.__init__) assert "objective_target" in sig.parameters, ( - f"{spec.attack_class.__name__} is missing required 'objective_target' parameter" + f"{factory.attack_class.__name__} is missing required 'objective_target' parameter" ) - def test_spec_names_are_unique(self): - """No two specs should share the same name.""" - names = [spec.name for spec in SCENARIO_TECHNIQUES] - assert len(names) == len(set(names)), f"Duplicate spec names: {[n for n in names if names.count(n) > 1]}" - - @pytest.mark.parametrize("spec", SCENARIO_TECHNIQUES, ids=lambda s: s.name) - def test_spec_adversarial_fields_not_both_set(self, spec: AttackTechniqueSpec): - """adversarial_chat and adversarial_chat_key must be mutually exclusive.""" - assert not (spec.adversarial_chat and spec.adversarial_chat_key), ( - f"Spec '{spec.name}' sets both adversarial_chat and adversarial_chat_key" + def test_factory_names_are_unique(self): + """No two factories should share the same name.""" + names = [f.name for f in _scenario_factories()] + assert len(names) == len(set(names)), ( + f"Duplicate factory names: {[n for n in names if names.count(n) > 1]}" ) @@ -295,12 +318,21 @@ def _make_generic_scoring_config(self): mock_scorer = MagicMock(spec=TrueFalseScorer) return AttackScoringConfig(objective_scorer=mock_scorer) + def _make_adversarial_config(self): + """Create an AttackAdversarialConfig wrapping a mock chat target.""" + from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig + + chat = MagicMock(spec=PromptTarget) + return AttackAdversarialConfig(target=chat) + def test_tap_factory_rejects_generic_config_with_raise_policy(self): """TAP factory raises when given a generic AttackScoringConfig and policy is RAISE.""" from pyrit.executor.attack.multi_turn.tree_of_attacks import TreeOfAttacksWithPruningAttack factory = AttackTechniqueFactory( + name="tap_raise", attack_class=TreeOfAttacksWithPruningAttack, + adversarial_config=self._make_adversarial_config(), scorer_override_policy=ScorerOverridePolicy.RAISE, ) @@ -320,24 +352,28 @@ def test_tap_factory_warns_on_generic_config_with_warn_policy(self, caplog): from pyrit.executor.attack.multi_turn.tree_of_attacks import TreeOfAttacksWithPruningAttack factory = AttackTechniqueFactory( + name="tap_warn", attack_class=TreeOfAttacksWithPruningAttack, + adversarial_config=self._make_adversarial_config(), scorer_override_policy=ScorerOverridePolicy.WARN, ) generic_config = self._make_generic_scoring_config() target = MagicMock(spec=PromptTarget) - # TAP will fail downstream (missing adversarial config), but the scorer - # override should be skipped with a warning — not a scorer ValueError. + # Under WARN policy, the scorer override should be skipped with a warning + # rather than raising. The factory.create() call may succeed or fail for + # unrelated downstream reasons — we only assert that no scorer-incompatibility + # ValueError was raised and that a warning was emitted. with caplog.at_level(logging.WARNING): - with pytest.raises(Exception) as exc_info: + try: factory.create( objective_target=target, attack_scoring_config=generic_config, ) + except Exception as exc: + assert "incompatible" not in str(exc).lower() - # The downstream error should NOT be about scorer incompatibility - assert "incompatible" not in str(exc_info.value).lower() # A warning about incompatibility should be logged assert any("incompatible" in record.message.lower() for record in caplog.records) @@ -348,24 +384,29 @@ def test_tap_factory_silently_skips_on_generic_config_with_skip_policy(self, cap from pyrit.executor.attack.multi_turn.tree_of_attacks import TreeOfAttacksWithPruningAttack factory = AttackTechniqueFactory( + name="tap_skip", attack_class=TreeOfAttacksWithPruningAttack, + adversarial_config=self._make_adversarial_config(), scorer_override_policy=ScorerOverridePolicy.SKIP, ) generic_config = self._make_generic_scoring_config() target = MagicMock(spec=PromptTarget) + # Under SKIP policy, the scorer override should be skipped silently. The + # factory.create() call may succeed or fail for unrelated downstream reasons + # — we only assert that no scorer-incompatibility error or warning was emitted. with caplog.at_level(logging.WARNING): - with pytest.raises(Exception) as exc_info: + try: factory.create( objective_target=target, attack_scoring_config=generic_config, ) + except Exception as exc: + assert "incompatible" not in str(exc).lower() # No warning about incompatibility should be logged assert not any("incompatible" in record.message.lower() for record in caplog.records) - # Downstream error should not mention scorer incompatibility - assert "incompatible" not in str(exc_info.value).lower() def test_tap_factory_accepts_tap_scoring_config(self): """TAP factory forwards TAPAttackScoringConfig regardless of policy.""" @@ -376,7 +417,9 @@ def test_tap_factory_accepts_tap_scoring_config(self): from pyrit.score import FloatScaleThresholdScorer factory = AttackTechniqueFactory( + name="tap_accept", attack_class=TreeOfAttacksWithPruningAttack, + adversarial_config=self._make_adversarial_config(), scorer_override_policy=ScorerOverridePolicy.RAISE, ) @@ -385,16 +428,16 @@ def test_tap_factory_accepts_tap_scoring_config(self): tap_config = TAPAttackScoringConfig(objective_scorer=mock_scorer) target = MagicMock(spec=PromptTarget) - # TAP will fail downstream (adversarial config missing), but - # the factory should NOT raise about scorer incompatibility - with pytest.raises(Exception) as exc_info: + # The factory should NOT raise about scorer incompatibility for a TAP-typed + # scoring config. Downstream construction may succeed or fail for unrelated + # reasons — we only assert no scorer-compatibility error is raised. + try: factory.create( objective_target=target, attack_scoring_config=tap_config, ) - - # The error should NOT be about scorer compatibility - assert "incompatible" not in str(exc_info.value).lower() + except Exception as exc: + assert "incompatible" not in str(exc).lower() def test_prompt_sending_factory_accepts_any_config(self): """PromptSendingAttack accepts base AttackScoringConfig — any config passes through.""" @@ -402,6 +445,7 @@ def test_prompt_sending_factory_accepts_any_config(self): from pyrit.memory import CentralMemory factory = AttackTechniqueFactory( + name="ps_any", attack_class=PromptSendingAttack, scorer_override_policy=ScorerOverridePolicy.RAISE, ) @@ -429,6 +473,7 @@ def test_prompt_sending_factory_accepts_tap_scoring_config(self): from pyrit.score import FloatScaleThresholdScorer factory = AttackTechniqueFactory( + name="ps_tap", attack_class=PromptSendingAttack, scorer_override_policy=ScorerOverridePolicy.RAISE, ) @@ -453,6 +498,7 @@ def test_prompt_sending_factory_accepts_tap_scoring_config(self): def test_factory_raises_when_attack_has_no_scoring_param_and_policy_raise(self): """Factory raises when attack doesn't accept attack_scoring_config and policy is RAISE.""" factory = AttackTechniqueFactory( + name="stub_noscorer_raise", attack_class=_StubAttackNoScorer, scorer_override_policy=ScorerOverridePolicy.RAISE, ) @@ -471,6 +517,7 @@ def test_factory_warns_when_attack_has_no_scoring_param_and_policy_warn(self, ca import logging factory = AttackTechniqueFactory( + name="stub_noscorer_warn", attack_class=_StubAttackNoScorer, scorer_override_policy=ScorerOverridePolicy.WARN, ) @@ -492,6 +539,7 @@ def test_factory_skips_silently_when_attack_has_no_scoring_param_and_policy_skip import logging factory = AttackTechniqueFactory( + name="stub_noscorer_skip", attack_class=_StubAttackNoScorer, scorer_override_policy=ScorerOverridePolicy.SKIP, ) diff --git a/tests/unit/scenario/test_adversarial.py b/tests/unit/scenario/test_adversarial.py index 5914a40ba9..1d652e29b4 100644 --- a/tests/unit/scenario/test_adversarial.py +++ b/tests/unit/scenario/test_adversarial.py @@ -3,8 +3,6 @@ """Tests for the AdversarialBenchmark scenario.""" -import copy -from dataclasses import FrozenInstanceError from unittest.mock import MagicMock, patch import pytest @@ -21,27 +19,45 @@ SeedPrompt, ) from pyrit.prompt_target import PromptTarget, TargetCapabilities, TargetConfiguration +from pyrit.registry import TargetRegistry from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy from pyrit.scenario.core.dataset_configuration import DatasetConfiguration -from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES from pyrit.scenario.scenarios.benchmark.adversarial import AdversarialBenchmark from pyrit.score import TrueFalseScorer +from pyrit.setup.initializers.components.scenario_techniques import build_scenario_technique_factories -# Self-pinned: any change to ``_get_benchmarkable_specs`` (or to the ``light`` tag -# membership in SCENARIO_TECHNIQUES) is reflected automatically — no magic numbers. -# -# ``_BENCHMARKABLE_*`` covers every adversarial-capable spec (used to verify the -# strategy enum's full concrete-member roster). ``_LIGHT_BENCHMARKABLE_*`` covers -# only the subset tagged ``"light"`` (used for runtime expectations under the -# default ``"light"`` strategy). -_BENCHMARKABLE_SPECS = AdversarialBenchmark._get_benchmarkable_specs() -_NUM_ADVERSARIAL_TECHNIQUES = len(_BENCHMARKABLE_SPECS) -_BENCHMARKABLE_TECHNIQUE_NAMES = {spec.name for spec in _BENCHMARKABLE_SPECS} -_BENCHMARKABLE_ATTACK_CLASSES = {spec.attack_class for spec in _BENCHMARKABLE_SPECS} - -_LIGHT_BENCHMARKABLE_SPECS = [spec for spec in _BENCHMARKABLE_SPECS if "light" in spec.strategy_tags] -_NUM_LIGHT_BENCHMARKABLE = len(_LIGHT_BENCHMARKABLE_SPECS) + +def _build_benchmarkable_factories_snapshot() -> list: + """Build the benchmarkable-factory snapshot used by module-level test constants. + + Sets up a mock ``adversarial_chat`` in ``TargetRegistry`` so factory + construction does not depend on environment variables, then filters the + canonical scenario factories by the same predicate used by + ``AdversarialBenchmark._get_benchmarkable_factories``. + """ + TargetRegistry.reset_instance() + adv = MagicMock(spec=PromptTarget) + adv.capabilities.includes.return_value = True + TargetRegistry.get_registry_singleton().register_instance(adv, name="adversarial_chat") + try: + factories = build_scenario_technique_factories() + finally: + TargetRegistry.reset_instance() + return [ + f for f in factories if f.uses_adversarial and "core" in f.strategy_tags + ] + + +# Self-pinned: any change to ``_get_benchmarkable_factories`` (or to the ``light`` tag +# membership in the canonical factory catalog) is reflected automatically — no magic numbers. +_BENCHMARKABLE_FACTORIES = _build_benchmarkable_factories_snapshot() +_NUM_ADVERSARIAL_TECHNIQUES = len(_BENCHMARKABLE_FACTORIES) +_BENCHMARKABLE_TECHNIQUE_NAMES = {f.name for f in _BENCHMARKABLE_FACTORIES} +_BENCHMARKABLE_ATTACK_CLASSES = {f.attack_class for f in _BENCHMARKABLE_FACTORIES} + +_LIGHT_BENCHMARKABLE_FACTORIES = [f for f in _BENCHMARKABLE_FACTORIES if "light" in f.strategy_tags] +_NUM_LIGHT_BENCHMARKABLE = len(_LIGHT_BENCHMARKABLE_FACTORIES) # --------------------------------------------------------------------------- # Synthetic many-shot examples — prevents reading the real JSON during tests @@ -127,12 +143,22 @@ def single_adversarial_model(): @pytest.fixture(autouse=True) def reset_technique_registry(): - """Reset the AttackTechniqueRegistry and cached strategy class between tests.""" - from pyrit.registry import TargetRegistry + """Reset registries, populate scenario factories, and clear cached strategy class. + Registers a mock adversarial target under ``adversarial_chat`` in + ``TargetRegistry`` so ``build_scenario_technique_factories`` resolves + without falling back to ``OpenAIChatTarget``. + """ AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() AdversarialBenchmark._cached_strategy_class = None + + adv_target = MagicMock(spec=PromptTarget) + adv_target.capabilities.includes.return_value = True + TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + + technique_registry = AttackTechniqueRegistry.get_registry_singleton() + technique_registry.register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() @@ -212,12 +238,6 @@ def test_default_dataset_config_max_size_is_8(self): config = AdversarialBenchmark.default_dataset_config() assert config.max_dataset_size == 8 - def test_frozen_spec_cannot_be_mutated(self): - """AttackTechniqueSpec is frozen — direct mutation must raise.""" - spec = SCENARIO_TECHNIQUES[0] - with pytest.raises(FrozenInstanceError): - spec.name = "mutated" # type: ignore[misc] - # =========================================================================== # Strategy construction tests @@ -248,10 +268,9 @@ def test_strategy_has_no_permuted_members(self): assert not any("__" in v for v in values) def test_strategy_excludes_non_adversarial_techniques(self): - """prompt_sending and many_shot don't accept an adversarial chat and must be excluded.""" + """many_shot doesn't accept an adversarial chat and must be excluded.""" strat = AdversarialBenchmark.get_strategy_class() values = {s.value for s in strat.get_all_strategies()} - assert "prompt_sending" not in values assert "many_shot" not in values def test_strategy_class_is_static(self, single_adversarial_model, two_adversarial_models): @@ -267,21 +286,22 @@ def test_default_strategy_is_light(self): assert default.value == "light" def test_benchmarkable_specs_have_no_adversarial_chat(self): - """Filtered specs must leave adversarial_chat unset — the scenario injects its own.""" - for spec in AdversarialBenchmark._get_benchmarkable_specs(): - assert spec.adversarial_chat is None + """Benchmarkable factories must be tagged ``core`` (excludes persona variants).""" + for factory in AdversarialBenchmark._get_benchmarkable_factories(): + assert "core" in factory.strategy_tags def test_benchmarkable_specs_accept_adversarial(self): - """All filtered specs must accept attack_adversarial_config.""" - for spec in AdversarialBenchmark._get_benchmarkable_specs(): - assert AttackTechniqueRegistry._accepts_adversarial(spec.attack_class) + """All filtered factories drive an adversarial chat.""" + for factory in AdversarialBenchmark._get_benchmarkable_factories(): + assert factory.uses_adversarial is True def test_original_scenario_techniques_unmodified(self, two_adversarial_models): - """SCENARIO_TECHNIQUES global must not be mutated by spec filtering.""" - original = copy.deepcopy([(s.name, s.attack_class) for s in SCENARIO_TECHNIQUES]) + """The benchmark's factory filter must not mutate the registry.""" + registry = AttackTechniqueRegistry.get_registry_singleton() + before = sorted(registry.get_names()) _make_benchmark(two_adversarial_models) - current = [(s.name, s.attack_class) for s in SCENARIO_TECHNIQUES] - assert current == original + after = sorted(registry.get_names()) + assert before == after def test_singleton_registry_not_polluted(self, two_adversarial_models): """Building atomic attacks must not register anything in the global singleton.""" diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py index 6b78a282ad..5b7d47e37b 100644 --- a/tests/unit/scenario/test_attack_technique_factory.py +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -58,21 +58,21 @@ class TestFactoryInit: """Tests for AttackTechniqueFactory construction and validation.""" def test_init_defaults(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) assert factory.attack_class is _StubAttack assert factory.seed_technique is None def test_init_stores_seed_technique(self): seeds = _make_seed_technique() - factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack, seed_technique=seeds) assert factory.seed_technique is seeds def test_validate_kwargs_accepts_valid_params(self): """All valid kwarg names should pass without error.""" factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, ) assert factory.attack_class is _StubAttack @@ -81,7 +81,7 @@ def test_validate_kwargs_rejects_unknown_params(self): """Typo or nonexistent kwarg should raise TypeError immediately.""" with pytest.raises(TypeError, match="Invalid kwargs.*max_turn"): AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turn": 10}, # typo: should be max_turns ) @@ -90,7 +90,7 @@ def test_validate_kwargs_rejects_objective_target(self): target = MagicMock(spec=PromptTarget) with pytest.raises(ValueError, match="objective_target must not be in attack_kwargs"): AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"objective_target": target}, ) @@ -98,7 +98,7 @@ def test_validate_kwargs_rejects_multiple_invalid(self): """Multiple bad kwargs should all be reported.""" with pytest.raises(TypeError, match="Invalid kwargs"): AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"bad_param_1": 1, "bad_param_2": 2}, ) @@ -110,7 +110,7 @@ def __init__(self, **kwargs): pass with pytest.raises(TypeError, match="accepts \\*\\*kwargs.*parameter validation"): - AttackTechniqueFactory(attack_class=_KwargsAttack) + AttackTechniqueFactory(name="test", attack_class=_KwargsAttack) def test_validate_kwargs_rejects_var_keyword_even_with_named_params(self): """Mixed named params + **kwargs should still be rejected.""" @@ -121,7 +121,7 @@ def __init__(self, *, objective_target, max_turns: int = 5, **extra): with pytest.raises(TypeError, match="accepts \\*\\*kwargs"): AttackTechniqueFactory( - attack_class=_MixedAttack, + name="test", attack_class=_MixedAttack, attack_kwargs={"max_turns": 10}, ) @@ -131,14 +131,14 @@ def test_validate_kwargs_works_with_real_attack_class(self): and functools.wraps on a real AttackStrategy subclass. """ # PromptSendingAttack uses @apply_defaults — factory should see its real params - factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + factory = AttackTechniqueFactory(name="test", attack_class=PromptSendingAttack) assert factory.attack_class is PromptSendingAttack def test_validate_kwargs_rejects_invalid_param_on_real_attack_class(self): """A typo kwarg should be caught even through @apply_defaults.""" with pytest.raises(TypeError, match="Invalid kwargs.*nonexistent_param"): AttackTechniqueFactory( - attack_class=PromptSendingAttack, + name="test", attack_class=PromptSendingAttack, attack_kwargs={"nonexistent_param": 42}, ) @@ -150,7 +150,7 @@ def _scoring(self) -> AttackScoringConfig: return MagicMock(spec=AttackScoringConfig) def test_create_produces_attack_technique(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) target = MagicMock(spec=PromptTarget) technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) @@ -161,7 +161,7 @@ def test_create_produces_attack_technique(self): def test_create_passes_frozen_kwargs(self): factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 42}, ) target = MagicMock(spec=PromptTarget) @@ -171,7 +171,7 @@ def test_create_passes_frozen_kwargs(self): assert technique.attack.max_turns == 42 def test_create_passes_scoring_config(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) target = MagicMock(spec=PromptTarget) scoring = MagicMock(spec=AttackScoringConfig) @@ -183,7 +183,7 @@ def test_create_overrides_frozen_scoring_config(self): """Create-time scoring config should override the frozen one.""" frozen_scoring = MagicMock(spec=AttackScoringConfig) factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"attack_scoring_config": frozen_scoring}, ) target = MagicMock(spec=PromptTarget) @@ -196,7 +196,7 @@ def test_create_overrides_frozen_scoring_config(self): def test_create_preserves_seed_technique(self): seeds = _make_seed_technique() - factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seeds) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack, seed_technique=seeds) target = MagicMock(spec=PromptTarget) technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) @@ -206,7 +206,7 @@ def test_create_preserves_seed_technique(self): def test_create_produces_independent_instances(self): """Two create() calls should produce fully independent attack instances.""" factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 10}, ) target1 = MagicMock(spec=PromptTarget) @@ -233,7 +233,7 @@ def get_identifier(self): return ComponentIdentifier(class_name="_ListAttack", class_module="test") factory = AttackTechniqueFactory( - attack_class=_ListAttack, + name="test", attack_class=_ListAttack, attack_kwargs={"items": mutable_list}, ) target = MagicMock(spec=PromptTarget) @@ -266,7 +266,9 @@ def __init__( def get_identifier(self): return ComponentIdentifier(class_name="_SentinelAttack", class_module="test") - factory = AttackTechniqueFactory(attack_class=_SentinelAttack) + factory = AttackTechniqueFactory( + name="test", attack_class=_SentinelAttack, uses_adversarial=False + ) target = MagicMock(spec=PromptTarget) technique = factory.create(objective_target=target, attack_scoring_config=self._scoring()) @@ -278,7 +280,7 @@ class TestFactoryIdentifier: """Tests for AttackTechniqueFactory._build_identifier().""" def test_identifier_includes_attack_class_name(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) identifier = factory.get_identifier() @@ -288,7 +290,7 @@ def test_identifier_includes_attack_class_name(self): def test_identifier_includes_kwargs_with_values(self): factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, ) @@ -297,7 +299,7 @@ def test_identifier_includes_kwargs_with_values(self): assert identifier.params["kwargs"] == {"attack_scoring_config": None, "max_turns": 10} def test_identifier_empty_kwargs(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) identifier = factory.get_identifier() @@ -306,11 +308,11 @@ def test_identifier_empty_kwargs(self): def test_same_keys_different_values_produce_different_hashes(self): """Two factories with max_turns=5 vs max_turns=50 must have different hashes.""" factory1 = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 5}, ) factory2 = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 50}, ) @@ -318,11 +320,11 @@ def test_same_keys_different_values_produce_different_hashes(self): def test_different_kwargs_keys_produce_different_hashes(self): factory1 = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 10}, ) factory2 = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, attack_kwargs={"max_turns": 10, "attack_scoring_config": None}, ) @@ -346,7 +348,7 @@ def get_identifier(self): return ComponentIdentifier(class_name="_IdentifiableParamAttack", class_module="test") factory = AttackTechniqueFactory( - attack_class=_IdentifiableParamAttack, + name="test", attack_class=_IdentifiableParamAttack, attack_kwargs={"config": mock_identifiable}, ) @@ -357,7 +359,7 @@ def get_identifier(self): assert config_value == expected_id.hash def test_identifier_is_cached(self): - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) first = factory.get_identifier() second = factory.get_identifier() @@ -367,7 +369,7 @@ def test_identifier_is_cached(self): def test_seed_technique_included_in_identifier(self): """A factory with seed_technique should have technique_seeds children.""" seed_technique = _make_seed_technique() - factory = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seed_technique) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack, seed_technique=seed_technique) identifier = factory.get_identifier() @@ -376,7 +378,7 @@ def test_seed_technique_included_in_identifier(self): def test_no_seed_technique_means_no_children(self): """A factory without seed_technique should have no technique_seeds children.""" - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) identifier = factory.get_identifier() @@ -390,8 +392,8 @@ def test_different_seed_techniques_produce_different_hashes(self): seed2 = SeedAttackTechniqueGroup( seeds=[SeedPrompt(value="technique_b", data_type="text", is_general_technique=True)], ) - factory1 = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seed1) - factory2 = AttackTechniqueFactory(attack_class=_StubAttack, seed_technique=seed2) + factory1 = AttackTechniqueFactory(name="test", attack_class=_StubAttack, seed_technique=seed1) + factory2 = AttackTechniqueFactory(name="test", attack_class=_StubAttack, seed_technique=seed2) assert factory1.get_identifier().hash != factory2.get_identifier().hash @@ -401,7 +403,7 @@ class TestScorerPolicy: def test_should_apply_returns_true_when_type_compatible(self): """Config passes through when the attack accepts base AttackScoringConfig.""" - factory = AttackTechniqueFactory(attack_class=_StubAttack) + factory = AttackTechniqueFactory(name="test", attack_class=_StubAttack) config = MagicMock(spec=AttackScoringConfig) result = factory._should_apply_scoring_config( @@ -422,7 +424,7 @@ def get_identifier(self): return ComponentIdentifier(class_name="_NoScoringAttack", class_module="test") factory = AttackTechniqueFactory( - attack_class=_NoScoringAttack, + name="test", attack_class=_NoScoringAttack, scorer_override_policy=ScorerOverridePolicy.SKIP, ) config = MagicMock(spec=AttackScoringConfig) @@ -448,7 +450,7 @@ def get_identifier(self): return ComponentIdentifier(class_name="_NarrowedAttack", class_module="test") factory = AttackTechniqueFactory( - attack_class=_NarrowedAttack, + name="test", attack_class=_NarrowedAttack, scorer_override_policy=ScorerOverridePolicy.WARN, ) config = MagicMock(spec=AttackScoringConfig) @@ -475,7 +477,7 @@ def get_identifier(self): return ComponentIdentifier(class_name="_NarrowedAttack", class_module="test") factory = AttackTechniqueFactory( - attack_class=_NarrowedAttack, + name="test", attack_class=_NarrowedAttack, scorer_override_policy=ScorerOverridePolicy.RAISE, ) config = MagicMock(spec=AttackScoringConfig) @@ -500,7 +502,7 @@ def get_identifier(self): return ComponentIdentifier(class_name="_NarrowedAttack", class_module="test") factory = AttackTechniqueFactory( - attack_class=_NarrowedAttack, + name="test", attack_class=_NarrowedAttack, scorer_override_policy=ScorerOverridePolicy.RAISE, ) config = MagicMock(spec=_NarrowedScoringConfig) @@ -515,7 +517,7 @@ def get_identifier(self): def test_apply_scorer_policy_skip_is_silent(self, caplog): """SKIP policy should not log or raise.""" factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, scorer_override_policy=ScorerOverridePolicy.SKIP, ) @@ -526,7 +528,7 @@ def test_apply_scorer_policy_skip_is_silent(self, caplog): def test_apply_scorer_policy_warn_logs(self, caplog): """WARN policy should log a warning.""" factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, scorer_override_policy=ScorerOverridePolicy.WARN, ) @@ -537,7 +539,7 @@ def test_apply_scorer_policy_warn_logs(self, caplog): def test_apply_scorer_policy_raise_raises(self): """RAISE policy should raise ValueError with the message.""" factory = AttackTechniqueFactory( - attack_class=_StubAttack, + name="test", attack_class=_StubAttack, scorer_override_policy=ScorerOverridePolicy.RAISE, ) diff --git a/tests/unit/scenario/test_cyber.py b/tests/unit/scenario/test_cyber.py index d519e8913f..4f537b3d38 100644 --- a/tests/unit/scenario/test_cyber.py +++ b/tests/unit/scenario/test_cyber.py @@ -7,15 +7,17 @@ import pytest -from pyrit.executor.attack import PromptSendingAttack, RedTeamingAttack +from pyrit.executor.attack import RedTeamingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.dataset_configuration import DatasetConfiguration -from pyrit.scenario.core.scenario_techniques import register_scenario_techniques from pyrit.scenario.scenarios.airt.cyber import Cyber from pyrit.score import TrueFalseScorer +from pyrit.setup.initializers.components.scenario_techniques import ( + build_scenario_technique_factories, +) # --------------------------------------------------------------------------- # Helpers @@ -59,12 +61,26 @@ def mock_objective_scorer(): @pytest.fixture(autouse=True) def reset_technique_registry(): - """Reset the AttackTechniqueRegistry, TargetRegistry, and cached strategy class between tests.""" + """Reset registries, populate scenario factories, and clear cached strategy class. + + Registers a mock adversarial target under ``adversarial_chat`` in + ``TargetRegistry`` so ``build_scenario_technique_factories`` can resolve + it without falling back to ``OpenAIChatTarget`` (which would require + central memory). + """ from pyrit.registry import TargetRegistry AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() Cyber._cached_strategy_class = None + + adv_target = MagicMock(spec=PromptTarget) + adv_target.capabilities.includes.return_value = True + target_registry = TargetRegistry.get_registry_singleton() + target_registry.register_instance(adv_target, name="adversarial_chat") + + technique_registry = AttackTechniqueRegistry.get_registry_singleton() + technique_registry.register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() @@ -217,7 +233,7 @@ async def _init_and_get_attacks( await scenario.initialize_async(**init_kwargs) return await scenario._get_atomic_attacks_async() - async def test_all_strategy_produces_prompt_sending_and_red_teaming( + async def test_all_strategy_produces_red_teaming( self, mock_objective_target, mock_objective_scorer ): attacks = await self._init_and_get_attacks( @@ -226,16 +242,7 @@ async def test_all_strategy_produces_prompt_sending_and_red_teaming( strategies=[_strategy_class().ALL], ) technique_classes = {type(a.attack_technique.attack) for a in attacks} - assert technique_classes == {PromptSendingAttack, RedTeamingAttack} - - async def test_single_turn_strategy_produces_prompt_sending(self, mock_objective_target, mock_objective_scorer): - attacks = await self._init_and_get_attacks( - mock_objective_target=mock_objective_target, - mock_objective_scorer=mock_objective_scorer, - strategies=[_strategy_class().SINGLE_TURN], - ) - technique_classes = {type(a.attack_technique.attack) for a in attacks} - assert technique_classes == {PromptSendingAttack} + assert technique_classes == {RedTeamingAttack} async def test_multi_turn_strategy_produces_red_teaming(self, mock_objective_target, mock_objective_scorer): attacks = await self._init_and_get_attacks( @@ -246,24 +253,26 @@ async def test_multi_turn_strategy_produces_red_teaming(self, mock_objective_tar technique_classes = {type(a.attack_technique.attack) for a in attacks} assert technique_classes == {RedTeamingAttack} - async def test_default_strategy_produces_both_techniques(self, mock_objective_target, mock_objective_scorer): - """Default (ALL) should produce both PromptSending and RedTeaming.""" + async def test_default_strategy_produces_red_teaming(self, mock_objective_target, mock_objective_scorer): + """Default (ALL) should produce RedTeaming. PromptSendingAttack baseline is + prepended automatically by BaselineAttackPolicy.Enabled when + include_baseline=True (the helper here uses include_baseline=False).""" attacks = await self._init_and_get_attacks( mock_objective_target=mock_objective_target, mock_objective_scorer=mock_objective_scorer, ) technique_classes = {type(a.attack_technique.attack) for a in attacks} - assert technique_classes == {PromptSendingAttack, RedTeamingAttack} + assert technique_classes == {RedTeamingAttack} async def test_single_technique_selection(self, mock_objective_target, mock_objective_scorer): attacks = await self._init_and_get_attacks( mock_objective_target=mock_objective_target, mock_objective_scorer=mock_objective_scorer, - strategies=[_strategy_class()("prompt_sending")], + strategies=[_strategy_class()("red_teaming")], ) assert len(attacks) > 0 for a in attacks: - assert isinstance(a.attack_technique.attack, PromptSendingAttack) + assert isinstance(a.attack_technique.attack, RedTeamingAttack) async def test_atomic_attack_names_are_unique(self, mock_objective_target, mock_objective_scorer): attacks = await self._init_and_get_attacks( @@ -279,7 +288,7 @@ async def test_attacks_include_seed_groups(self, mock_objective_target, mock_obj attacks = await self._init_and_get_attacks( mock_objective_target=mock_objective_target, mock_objective_scorer=mock_objective_scorer, - strategies=[_strategy_class()("prompt_sending")], + strategies=[_strategy_class()("red_teaming")], ) for a in attacks: assert len(a.objectives) > 0 @@ -314,24 +323,24 @@ def test_cyber_strategy_resolves_from_module(self): class TestCyberRegistryIntegration: """Tests for attack technique registry wiring via Cyber scenario.""" - def test_cyber_factories_include_prompt_sending_and_red_teaming(self, mock_objective_scorer): + def test_cyber_factories_include_red_teaming(self, mock_objective_scorer): scenario = Cyber(objective_scorer=mock_objective_scorer) factories = scenario._get_attack_technique_factories() - # Cyber uses all registered techniques from the registry; prompt_sending + red_teaming are present - assert "prompt_sending" in factories + # Cyber filters the registry to red_teaming; the PromptSendingAttack baseline + # is contributed at runtime by BaselineAttackPolicy.Enabled, not by this dict. assert "red_teaming" in factories - assert factories["prompt_sending"].attack_class is PromptSendingAttack assert factories["red_teaming"].attack_class is RedTeamingAttack def test_red_teaming_factory_has_adversarial_config(self, mock_objective_scorer): - """red_teaming factory should have adversarial config baked in.""" + """red_teaming factory advertises uses_adversarial (config resolved lazily at create()).""" scenario = Cyber(objective_scorer=mock_objective_scorer) factories = scenario._get_attack_technique_factories() - assert factories["red_teaming"]._adversarial_config is not None + assert factories["red_teaming"].uses_adversarial is True + assert factories["red_teaming"]._adversarial_config is None def test_register_idempotent(self): - """Calling register_scenario_techniques twice doesn't duplicate entries.""" - register_scenario_techniques() - register_scenario_techniques() + """Registering the scenario technique factories twice doesn't duplicate entries.""" registry = AttackTechniqueRegistry.get_registry_singleton() + registry.register_from_factories(build_scenario_technique_factories()) + registry.register_from_factories(build_scenario_technique_factories()) assert len([n for n in registry.get_names() if n == "red_teaming"]) == 1 diff --git a/tests/unit/scenario/test_leakage_scenario.py b/tests/unit/scenario/test_leakage_scenario.py index c1b5659e74..7c1518632b 100644 --- a/tests/unit/scenario/test_leakage_scenario.py +++ b/tests/unit/scenario/test_leakage_scenario.py @@ -12,10 +12,13 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import PromptTarget +from pyrit.registry import TargetRegistry +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario import DatasetConfiguration from pyrit.scenario.airt import Leakage, LeakageStrategy from pyrit.scenario.core import BaselineAttackPolicy from pyrit.score import TrueFalseCompositeScorer +from pyrit.setup.initializers.components.scenario_techniques import build_scenario_technique_factories def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ComponentIdentifier: @@ -83,6 +86,25 @@ def mock_objective_scorer(): FIXTURES = ["patch_central_database", "mock_runtime_env"] +@pytest.fixture(autouse=True) +def reset_technique_registry(): + """Reset registries and populate scenario factories for each test.""" + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + Leakage._cached_strategy_class = None + + adv_target = MagicMock(spec=PromptTarget) + adv_target.capabilities.includes.return_value = True + TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + + technique_registry = AttackTechniqueRegistry.get_registry_singleton() + technique_registry.register_from_factories(build_scenario_technique_factories()) + yield + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + Leakage._cached_strategy_class = None + + @pytest.mark.usefixtures(*FIXTURES) class TestLeakageInitialization: """Tests for Leakage initialization.""" @@ -180,7 +202,7 @@ def test_scenario_version_is_set(self, mock_objective_scorer): def test_get_strategy_class_returns_dynamic_class(self): """Test that get_strategy_class returns a dynamically generated strategy class.""" strategy_class = Leakage.get_strategy_class() - assert strategy_class is LeakageStrategy + assert strategy_class.__name__ == "LeakageStrategy" def test_get_default_strategy_returns_default(self): """Test that get_default_strategy returns the DEFAULT aggregate.""" @@ -228,5 +250,4 @@ def test_strategy_has_technique_members(self): assert "first_letter" in values assert "image" in values # Core techniques included - assert "prompt_sending" in values assert "role_play" in values diff --git a/tests/unit/scenario/test_rapid_response.py b/tests/unit/scenario/test_rapid_response.py index ecaef3d02c..9a681af853 100644 --- a/tests/unit/scenario/test_rapid_response.py +++ b/tests/unit/scenario/test_rapid_response.py @@ -18,20 +18,18 @@ ) from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt -from pyrit.prompt_target import OpenAIChatTarget, PromptTarget -from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry, AttackTechniqueSpec +from pyrit.prompt_target import PromptTarget +from pyrit.registry import TargetRegistry +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import DatasetConfiguration -from pyrit.scenario.core.scenario_techniques import ( - SCENARIO_TECHNIQUES, - build_scenario_techniques, - get_default_adversarial_target, - register_scenario_techniques, -) from pyrit.scenario.scenarios.airt.rapid_response import ( RapidResponse, ) from pyrit.score import TrueFalseScorer +from pyrit.setup.initializers.components.scenario_techniques import ( + build_scenario_technique_factories, +) # --------------------------------------------------------------------------- # Synthetic many-shot examples — prevents reading the real JSON during tests @@ -81,12 +79,22 @@ def mock_objective_scorer(): @pytest.fixture(autouse=True) def reset_technique_registry(): - """Reset the AttackTechniqueRegistry, TargetRegistry, and cached strategy class between tests.""" - from pyrit.registry import TargetRegistry + """Reset registries, register a mock adversarial target, and populate factories. + The mock target satisfies the ``adversarial_chat`` slot so + ``build_scenario_technique_factories`` does not fall back to + ``OpenAIChatTarget``. + """ AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() RapidResponse._cached_strategy_class = None + + adv_target = MagicMock(spec=PromptTarget) + adv_target.capabilities.includes.return_value = True + TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + + technique_registry = AttackTechniqueRegistry.get_registry_singleton() + technique_registry.register_from_factories(build_scenario_technique_factories()) yield AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() @@ -251,7 +259,7 @@ async def _init_and_get_attacks( await scenario.initialize_async(**init_kwargs) return await scenario._get_atomic_attacks_async() - async def test_default_strategy_produces_prompt_sending_and_many_shot( + async def test_default_strategy_produces_role_play_and_many_shot( self, mock_objective_target, mock_objective_scorer ): attacks = await self._init_and_get_attacks( @@ -259,7 +267,7 @@ async def test_default_strategy_produces_prompt_sending_and_many_shot( mock_objective_scorer=mock_objective_scorer, ) technique_classes = {type(a.attack_technique.attack) for a in attacks} - assert technique_classes == {PromptSendingAttack, ManyShotJailbreakAttack} + assert technique_classes == {RolePlayAttack, ManyShotJailbreakAttack} async def test_single_turn_strategy_produces_single_turn_attacks( self, mock_objective_target, mock_objective_scorer @@ -270,7 +278,7 @@ async def test_single_turn_strategy_produces_single_turn_attacks( strategies=[_strategy_class().SINGLE_TURN], ) technique_classes = {type(a.attack_technique.attack) for a in attacks} - # Every core technique tagged ``single_turn`` in SCENARIO_TECHNIQUES must appear. + # Every core technique tagged ``single_turn`` in the scenario-technique catalog must appear. assert {PromptSendingAttack, RolePlayAttack, ContextComplianceAttack} <= technique_classes # And no multi-turn-only attack should leak in. assert ManyShotJailbreakAttack not in technique_classes @@ -307,11 +315,11 @@ async def test_single_technique_selection(self, mock_objective_target, mock_obje attacks = await self._init_and_get_attacks( mock_objective_target=mock_objective_target, mock_objective_scorer=mock_objective_scorer, - strategies=[_strategy_class()("prompt_sending")], + strategies=[_strategy_class()("role_play")], ) assert len(attacks) > 0 for a in attacks: - assert isinstance(a.attack_technique.attack, PromptSendingAttack) + assert isinstance(a.attack_technique.attack, RolePlayAttack) async def test_attack_count_is_techniques_times_datasets(self, mock_objective_target, mock_objective_scorer): """With 2 datasets and DEFAULT (2 techniques), expect 4 atomic attacks.""" @@ -324,7 +332,7 @@ async def test_attack_count_is_techniques_times_datasets(self, mock_objective_ta mock_objective_scorer=mock_objective_scorer, seed_groups=two_datasets, ) - # DEFAULT = PromptSending + ManyShot = 2 techniques, 2 datasets → 4 + # DEFAULT = RolePlay + ManyShot = 2 techniques, 2 datasets → 4 assert len(attacks) == 4 async def test_atomic_attack_names_are_unique_compound_keys(self, mock_objective_target, mock_objective_scorer): @@ -370,21 +378,22 @@ async def test_unknown_technique_skipped_with_warning(self, mock_objective_targe """If a technique name has no factory, it's skipped (not an error).""" groups = {"hate": _make_seed_groups("hate")} - # Register only prompt_sending in the registry — the other techniques + # Reset the registry and register only prompt_sending — the other techniques # (role_play, many_shot, tap) won't have factories. + AttackTechniqueRegistry.reset_instance() + RapidResponse._cached_strategy_class = None registry = AttackTechniqueRegistry.get_registry_singleton() registry.register_technique( name="prompt_sending", - factory=AttackTechniqueFactory(attack_class=PromptSendingAttack), - tags=["single_turn"], + factory=AttackTechniqueFactory( + name="prompt_sending", + attack_class=PromptSendingAttack, + strategy_tags=["core", "single_turn"], + ), + tags=["core", "single_turn"], ) - with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), - patch( - "pyrit.scenario.core.scenario_techniques.register_scenario_techniques", - ), - ): + with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): scenario = RapidResponse( objective_scorer=mock_objective_scorer, ) @@ -404,7 +413,7 @@ async def test_attacks_include_seed_groups(self, mock_objective_target, mock_obj attacks = await self._init_and_get_attacks( mock_objective_target=mock_objective_target, mock_objective_scorer=mock_objective_scorer, - strategies=[_strategy_class()("prompt_sending")], + strategies=[_strategy_class()("role_play")], ) for a in attacks: assert len(a.objectives) > 0 @@ -445,28 +454,33 @@ class TestCoreTechniques: def test_instance_returns_all_factories(self, mock_objective_scorer): scenario = RapidResponse(objective_scorer=mock_objective_scorer) factories = scenario._get_attack_technique_factories() - assert {"prompt_sending", "role_play", "many_shot", "tap"} <= set(factories.keys()) - assert factories["prompt_sending"].attack_class is PromptSendingAttack + assert {"role_play", "many_shot", "tap"} <= set(factories.keys()) assert factories["role_play"].attack_class is RolePlayAttack assert factories["many_shot"].attack_class is ManyShotJailbreakAttack assert factories["tap"].attack_class is TreeOfAttacksWithPruningAttack def test_factories_use_default_adversarial_when_none(self, mock_objective_scorer): - """Factories use get_default_adversarial_target for adversarial config.""" + """Factories that need an adversarial chat mark themselves as adversarial. + + The default adversarial target is resolved lazily inside ``create()``; + it is not baked into the factory at construction time. + """ scenario = RapidResponse(objective_scorer=mock_objective_scorer) factories = scenario._get_attack_technique_factories() - # role_play and tap should have adversarial_config as first-class field - assert factories["role_play"]._adversarial_config is not None - assert factories["tap"]._adversarial_config is not None + assert factories["role_play"].uses_adversarial is True + assert factories["tap"].uses_adversarial is True + assert factories["role_play"]._adversarial_config is None + assert factories["tap"]._adversarial_config is None def test_factories_always_use_default_adversarial(self, mock_objective_scorer): - """Registry always bakes default adversarial target from get_default_adversarial_target.""" + """Factories defer adversarial wiring to create()-time lazy resolution.""" scenario = RapidResponse(objective_scorer=mock_objective_scorer) factories = scenario._get_attack_technique_factories() - # Factories have an adversarial config as first-class field - assert factories["role_play"]._adversarial_config is not None - assert factories["tap"]._adversarial_config is not None + assert factories["role_play"]._adversarial_config is None + assert factories["tap"]._adversarial_config is None + assert factories["role_play"]._adversarial_config_was_explicit is False + assert factories["tap"]._adversarial_config_was_explicit is False # =========================================================================== @@ -506,351 +520,117 @@ def test_content_harms_instance_name_is_rapid_response(self, mock_objective_scor @pytest.mark.usefixtures(*FIXTURES) class TestRegistryIntegration: - """Tests for AttackTechniqueRegistry wiring via register_scenario_techniques.""" + """Tests for AttackTechniqueRegistry wiring via build_scenario_technique_factories.""" - def test_register_populates_registry(self, mock_adversarial_target): - """After calling register_scenario_techniques(), all 4 techniques are in registry.""" - register_scenario_techniques() + def test_registry_populated_by_autouse_fixture(self): + """The autouse fixture registers all canonical scenario techniques.""" registry = AttackTechniqueRegistry.get_registry_singleton() names = set(registry.get_names()) - assert {"prompt_sending", "role_play", "many_shot", "tap"} <= names + assert {"role_play", "many_shot", "tap"} <= names - def test_register_idempotent(self, mock_adversarial_target): - """Calling register_scenario_techniques() twice doesn't duplicate entries.""" - register_scenario_techniques() - register_scenario_techniques() + def test_register_from_factories_idempotent(self): + """Calling register_from_factories twice does not duplicate entries.""" registry = AttackTechniqueRegistry.get_registry_singleton() - assert len(registry) == len(SCENARIO_TECHNIQUES) + expected = len(build_scenario_technique_factories()) + registry.register_from_factories(build_scenario_technique_factories()) + assert len(registry) == expected - def test_register_preserves_custom(self, mock_adversarial_target): - """Pre-registered custom techniques aren't overwritten.""" + def test_register_preserves_custom_preregistered(self): + """Pre-registered custom techniques are not overwritten by re-registration.""" registry = AttackTechniqueRegistry.get_registry_singleton() - custom_factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + custom_factory = AttackTechniqueFactory(name="role_play", attack_class=PromptSendingAttack) registry.register_technique(name="role_play", factory=custom_factory, tags=["custom"]) - register_scenario_techniques() - - # role_play should still be the custom factory - factories = registry.get_factories() - assert factories["role_play"] is custom_factory - assert len(factories) == len(SCENARIO_TECHNIQUES) + registry.register_from_factories(build_scenario_technique_factories()) + assert registry.get_factories()["role_play"] is custom_factory - def test_get_factories_returns_dict(self, mock_adversarial_target): - """get_factories() returns a dict of name → factory.""" - register_scenario_techniques() + def test_get_factories_returns_dict(self): registry = AttackTechniqueRegistry.get_registry_singleton() factories = registry.get_factories() assert isinstance(factories, dict) - assert {"prompt_sending", "role_play", "many_shot", "tap"} <= set(factories.keys()) - assert factories["prompt_sending"].attack_class is PromptSendingAttack + assert {"role_play", "many_shot", "tap"} <= set(factories.keys()) + assert factories["role_play"].attack_class is RolePlayAttack def test_scenario_base_class_reads_from_registry(self, mock_objective_scorer): - """Scenario._get_attack_technique_factories() triggers registration and reads from registry.""" + """Scenario._get_attack_technique_factories() reads from the registry.""" scenario = RapidResponse(objective_scorer=mock_objective_scorer) factories = scenario._get_attack_technique_factories() + assert {"role_play", "many_shot", "tap"} <= set(factories.keys()) - # Should have all core techniques from the registry - assert {"prompt_sending", "role_play", "many_shot", "tap"} <= set(factories.keys()) - - # Registry should also have them + def test_tags_assigned_correctly(self): registry = AttackTechniqueRegistry.get_registry_singleton() - assert {"prompt_sending", "role_play", "many_shot", "tap"} <= set(registry.get_names()) - - def test_tags_assigned_correctly(self, mock_adversarial_target): - """Core techniques have correct tags (single_turn / multi_turn).""" - register_scenario_techniques() - registry = AttackTechniqueRegistry.get_registry_singleton() - single_turn = {e.name for e in registry.get_by_tag(tag="single_turn")} multi_turn = {e.name for e in registry.get_by_tag(tag="multi_turn")} - - assert {"prompt_sending", "role_play"} <= single_turn + assert {"role_play"} <= single_turn assert {"many_shot", "tap"} <= multi_turn # =========================================================================== -# Registration and factory-from-spec tests +# build_scenario_technique_factories tests # =========================================================================== @pytest.mark.usefixtures(*FIXTURES) -class TestRegistrationAndFactoryFromSpec: - """Tests for register_scenario_techniques and AttackTechniqueRegistry.build_factory_from_spec.""" +class TestBuildScenarioTechniqueFactories: + """Tests for build_scenario_technique_factories() — the canonical factory catalog.""" + + def test_returns_nonempty_factory_list(self): + factories = build_scenario_technique_factories() + assert len(factories) >= 4 + names = [f.name for f in factories] + assert len(names) == len(set(names)), "Duplicate technique names" + + def test_adversarial_factories_have_adversarial_config(self): + """Factories that need an adversarial chat advertise it via uses_adversarial. + + The config itself is resolved lazily at create()-time. + """ + by_name = {f.name: f for f in build_scenario_technique_factories()} + assert by_name["role_play"].uses_adversarial is True + assert by_name["tap"].uses_adversarial is True + assert by_name["role_play"]._adversarial_config is None + assert by_name["tap"]._adversarial_config is None + + def test_non_adversarial_factories_have_no_adversarial_config(self): + by_name = {f.name: f for f in build_scenario_technique_factories()} + assert by_name["many_shot"]._adversarial_config is None - def test_register_populates_all_techniques(self): - """register_scenario_techniques registers all catalog techniques.""" - register_scenario_techniques() - registry = AttackTechniqueRegistry.get_registry_singleton() - registered = set(registry.get_names()) - expected = {s.name for s in SCENARIO_TECHNIQUES} - assert registered == expected - - def test_register_with_custom_adversarial_uses_default(self, mock_adversarial_target): - """Registry always bakes default adversarial target, not caller-specific.""" - register_scenario_techniques() - registry = AttackTechniqueRegistry.get_registry_singleton() - factories = registry.get_factories() - - # role_play and tap should have an adversarial config as first-class field - assert factories["role_play"]._adversarial_config is not None - assert factories["tap"]._adversarial_config is not None - - def test_register_idempotent(self, mock_adversarial_target): - """Calling register_scenario_techniques() twice does not duplicate or overwrite entries.""" - register_scenario_techniques() - register_scenario_techniques() - registry = AttackTechniqueRegistry.get_registry_singleton() - assert len(registry) == len(SCENARIO_TECHNIQUES) - - def test_register_preserves_custom_preregistered(self, mock_adversarial_target): - """Pre-registered custom techniques are not overwritten.""" - registry = AttackTechniqueRegistry.get_registry_singleton() - custom_factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) - registry.register_technique(name="role_play", factory=custom_factory, tags=["custom"]) - - register_scenario_techniques() - # role_play should still be the custom factory - assert registry.get_factories()["role_play"] is custom_factory - assert len(registry) == len(SCENARIO_TECHNIQUES) - - def test_register_assigns_correct_tags(self, mock_adversarial_target): - """Tags from AttackTechniqueSpec are applied correctly.""" - register_scenario_techniques() - registry = AttackTechniqueRegistry.get_registry_singleton() - - single_turn = {e.name for e in registry.get_by_tag(tag="single_turn")} - multi_turn = {e.name for e in registry.get_by_tag(tag="multi_turn")} - assert {"prompt_sending", "role_play"} <= single_turn - assert {"many_shot", "tap"} <= multi_turn - - def test_register_from_specs_custom_list(self, mock_adversarial_target): - """register_from_specs accepts a custom list of AttackTechniqueSpecs.""" - custom_specs = [ - AttackTechniqueSpec(name="custom_attack", attack_class=PromptSendingAttack, strategy_tags=["custom"]), - ] - registry = AttackTechniqueRegistry.get_registry_singleton() - registry.register_from_specs(custom_specs) - assert set(registry.get_names()) == {"custom_attack"} - - def test_get_default_adversarial_target_from_registry(self, mock_adversarial_target): - """get_default_adversarial_target returns registry entry when available.""" - from pyrit.registry import TargetRegistry - - target_registry = TargetRegistry.get_registry_singleton() - target_registry.register(name="adversarial_chat", instance=mock_adversarial_target) - result = get_default_adversarial_target() - assert result is mock_adversarial_target - - def test_get_default_adversarial_target_fallback(self): - """get_default_adversarial_target falls back to OpenAIChatTarget when not in registry.""" - result = get_default_adversarial_target() - assert isinstance(result, OpenAIChatTarget) - assert result._temperature == 1.2 + def test_crescendo_simulated_has_seed_technique(self): + by_name = {f.name: f for f in build_scenario_technique_factories()} + assert by_name["crescendo_simulated"].seed_technique is not None - def test_get_default_adversarial_target_capability_check(self): - """get_default_adversarial_target rejects targets without multi-turn support.""" - from pyrit.registry import TargetRegistry + def test_crescendo_simulated_has_adversarial_chat(self): + by_name = {f.name: f for f in build_scenario_technique_factories()} + assert by_name["crescendo_simulated"].uses_adversarial is True - target_registry = TargetRegistry.get_registry_singleton() - # Register a plain PromptTarget (lacks multi-turn capability) - mock_target = MagicMock(spec=PromptTarget) - mock_target.capabilities.includes.return_value = False - target_registry.register(name="adversarial_chat", instance=mock_target) - with pytest.raises(ValueError, match="must support"): - get_default_adversarial_target() + def test_extra_kwargs_preserved_on_role_play(self): + by_name = {f.name: f for f in build_scenario_technique_factories()} + assert "role_play_definition_path" in (by_name["role_play"]._attack_kwargs or {}) # =========================================================================== -# build_scenario_techniques tests +# AttackTechniqueFactory tests # =========================================================================== @pytest.mark.usefixtures(*FIXTURES) -class TestBuildScenarioTechniques: - """Tests for build_scenario_techniques() — the runtime spec transform.""" - - def test_returns_same_count_as_static_catalog(self): - specs = build_scenario_techniques() - assert len(specs) == len(SCENARIO_TECHNIQUES) - - def test_adversarial_specs_get_target(self): - specs = build_scenario_techniques() - by_name = {s.name: s for s in specs} - assert by_name["role_play"].adversarial_chat is not None - assert by_name["tap"].adversarial_chat is not None - - def test_non_adversarial_specs_unchanged(self): - specs = build_scenario_techniques() - by_name = {s.name: s for s in specs} - assert by_name["prompt_sending"].adversarial_chat is None - assert by_name["many_shot"].adversarial_chat is None - - def test_crescendo_simulated_has_seed_technique(self): - """crescendo_simulated spec declares a seed_technique.""" - by_name = {s.name: s for s in SCENARIO_TECHNIQUES} - spec = by_name["crescendo_simulated"] - assert spec.seed_technique is not None - - def test_crescendo_simulated_factory_has_adversarial_chat(self, mock_adversarial_target): - """After build_scenario_techniques, crescendo_simulated gets adversarial_chat from default.""" - register_scenario_techniques() - registry = AttackTechniqueRegistry.get_registry_singleton() - factories = registry.get_factories() - factory = factories["crescendo_simulated"] - assert factory.adversarial_chat is not None - - def test_extra_kwargs_preserved(self): - specs = build_scenario_techniques() - by_name = {s.name: s for s in specs} - assert "role_play_definition_path" in by_name["role_play"].extra_kwargs - - def test_derived_from_static_catalog(self): - """build_scenario_techniques is a transform of SCENARIO_TECHNIQUES — names match.""" - runtime_names = {s.name for s in build_scenario_techniques()} - static_names = {s.name for s in SCENARIO_TECHNIQUES} - assert runtime_names == static_names - - def test_adversarial_chat_key_resolves_from_registry(self, mock_adversarial_target): - """When adversarial_chat_key is set, it resolves the target from TargetRegistry.""" - from pyrit.registry import TargetRegistry - - registry = TargetRegistry.get_registry_singleton() - registry.register_instance(mock_adversarial_target, name="custom_adversarial") - - original = SCENARIO_TECHNIQUES.copy() - custom_spec = AttackTechniqueSpec( - name="tap", - attack_class=TreeOfAttacksWithPruningAttack, - strategy_tags=["core", "multi_turn"], - adversarial_chat_key="custom_adversarial", - ) - try: - SCENARIO_TECHNIQUES.clear() - SCENARIO_TECHNIQUES.append(custom_spec) - - specs = build_scenario_techniques() - assert specs[0].adversarial_chat is mock_adversarial_target - finally: - SCENARIO_TECHNIQUES.clear() - SCENARIO_TECHNIQUES.extend(original) - - def test_adversarial_chat_key_missing_raises(self): - """When adversarial_chat_key references a missing registry entry, ValueError is raised.""" - original = SCENARIO_TECHNIQUES.copy() - custom_spec = AttackTechniqueSpec( - name="tap", - attack_class=TreeOfAttacksWithPruningAttack, - strategy_tags=["core", "multi_turn"], - adversarial_chat_key="nonexistent_key", - ) - try: - SCENARIO_TECHNIQUES.clear() - SCENARIO_TECHNIQUES.append(custom_spec) - - with pytest.raises(ValueError, match="no such entry exists in TargetRegistry"): - build_scenario_techniques() - finally: - SCENARIO_TECHNIQUES.clear() - SCENARIO_TECHNIQUES.extend(original) - - -# =========================================================================== -# AttackTechniqueSpec tests -# =========================================================================== +class TestAttackTechniqueFactoryBasics: + """Tests for the AttackTechniqueFactory construction surface.""" - -@pytest.mark.usefixtures(*FIXTURES) -class TestAttackTechniqueSpec: - """Tests for the AttackTechniqueSpec dataclass.""" - - def test_simple_spec(self): - spec = AttackTechniqueSpec(name="test", attack_class=PromptSendingAttack, strategy_tags=["single_turn"]) - assert spec.name == "test" - assert spec.attack_class is PromptSendingAttack - assert spec.strategy_tags == ["single_turn"] - assert spec.adversarial_chat is None - assert spec.extra_kwargs == {} - - def test_extra_kwargs(self, mock_adversarial_target): - spec = AttackTechniqueSpec( - name="complex", - attack_class=RolePlayAttack, - strategy_tags=["single_turn"], - adversarial_chat=mock_adversarial_target, - extra_kwargs={"role_play_definition_path": "/custom/path.yaml"}, + def test_simple_factory(self): + factory = AttackTechniqueFactory( + name="test", attack_class=PromptSendingAttack, strategy_tags=["single_turn"] ) - factory = AttackTechniqueRegistry.build_factory_from_spec(spec) - assert factory._attack_kwargs["role_play_definition_path"] == "/custom/path.yaml" - assert factory._adversarial_config is not None - - def test_build_factory_no_adversarial_injected_when_attack_does_not_accept_it(self, mock_adversarial_target): - """adversarial config is stored on factory but not injected into attack_kwargs for non-adversarial attacks.""" - spec = AttackTechniqueSpec( - name="simple", - attack_class=PromptSendingAttack, - strategy_tags=[], - adversarial_chat=mock_adversarial_target, - ) - factory = AttackTechniqueRegistry.build_factory_from_spec(spec) - # Config is stored as first-class field (available via factory.adversarial_chat) - assert factory._adversarial_config is not None - # But NOT injected into attack_kwargs since PromptSendingAttack doesn't accept it - assert "attack_adversarial_config" not in (factory._attack_kwargs or {}) - - def test_extra_kwargs_reserved_key_raises(self): - """attack_adversarial_config must not appear in extra_kwargs.""" - spec = AttackTechniqueSpec( - name="bad", - attack_class=RolePlayAttack, - strategy_tags=[], - extra_kwargs={"attack_adversarial_config": "oops"}, - ) - with pytest.raises(ValueError, match="attack_adversarial_config"): - AttackTechniqueRegistry.build_factory_from_spec(spec) + assert factory.name == "test" + assert factory.attack_class is PromptSendingAttack + assert factory.strategy_tags == ["single_turn"] + assert factory.adversarial_chat is None def test_adversarial_config_rejected_in_attack_kwargs(self): """attack_adversarial_config in attack_kwargs raises ValueError at factory construction.""" with pytest.raises(ValueError, match="attack_adversarial_config"): AttackTechniqueFactory( + name="bad", attack_class=RolePlayAttack, attack_kwargs={"attack_adversarial_config": "oops"}, ) - - def test_scenario_techniques_list_nonempty_with_unique_names(self): - assert len(SCENARIO_TECHNIQUES) >= 1 - names = [s.name for s in SCENARIO_TECHNIQUES] - assert len(names) == len(set(names)), "Duplicate technique names in SCENARIO_TECHNIQUES" - - def test_frozen_spec(self): - """AttackTechniqueSpec is frozen (immutable).""" - spec = AttackTechniqueSpec(name="test", attack_class=PromptSendingAttack) - with pytest.raises(AttributeError): - spec.name = "modified" - - def test_adversarial_injected_when_attack_accepts_it(self, mock_adversarial_target): - """Adversarial config is injected based on attack class signature.""" - # RolePlayAttack accepts attack_adversarial_config → injected as first-class field - rp_spec = AttackTechniqueSpec( - name="rp", attack_class=RolePlayAttack, strategy_tags=[], adversarial_chat=mock_adversarial_target - ) - rp_factory = AttackTechniqueRegistry.build_factory_from_spec(rp_spec) - assert rp_factory._adversarial_config is not None - - # PromptSendingAttack does NOT accept it → config stored but not in attack_kwargs - ps_spec = AttackTechniqueSpec( - name="ps", attack_class=PromptSendingAttack, strategy_tags=[], adversarial_chat=mock_adversarial_target - ) - ps_factory = AttackTechniqueRegistry.build_factory_from_spec(ps_spec) - assert ps_factory._adversarial_config is not None - assert "attack_adversarial_config" not in (ps_factory._attack_kwargs or {}) - - def test_adversarial_chat_and_key_both_set_raises(self, mock_adversarial_target): - """Setting both adversarial_chat and adversarial_chat_key raises ValueError at construction.""" - with pytest.raises(ValueError, match="mutually exclusive"): - AttackTechniqueSpec( - name="tap", - attack_class=TreeOfAttacksWithPruningAttack, - strategy_tags=["core", "multi_turn"], - adversarial_chat=mock_adversarial_target, - adversarial_chat_key="some_key", - ) diff --git a/tests/unit/scenario/test_scenario_strategy_invariants.py b/tests/unit/scenario/test_scenario_strategy_invariants.py index 3fa9bbaa86..b5af069bc2 100644 --- a/tests/unit/scenario/test_scenario_strategy_invariants.py +++ b/tests/unit/scenario/test_scenario_strategy_invariants.py @@ -29,15 +29,26 @@ @pytest.fixture(autouse=True) def _reset_registries(): - """Reset singletons and cached strategy classes between every test.""" + """Reset singletons, populate factories, and clear cached strategy classes between tests.""" + from unittest.mock import MagicMock + + from pyrit.prompt_target import PromptTarget from pyrit.registry import TargetRegistry from pyrit.scenario.scenarios.airt.cyber import Cyber from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse + from pyrit.setup.initializers.components.scenario_techniques import build_scenario_technique_factories AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() Cyber._cached_strategy_class = None RapidResponse._cached_strategy_class = None + + adv_target = MagicMock(spec=PromptTarget) + adv_target.capabilities.includes.return_value = True + TargetRegistry.get_registry_singleton().register_instance(adv_target, name="adversarial_chat") + AttackTechniqueRegistry.get_registry_singleton().register_from_factories( + build_scenario_technique_factories() + ) yield AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() diff --git a/tests/unit/setup/test_scenarios_initializer.py b/tests/unit/setup/test_scenario_techniques_initializer.py similarity index 70% rename from tests/unit/setup/test_scenarios_initializer.py rename to tests/unit/setup/test_scenario_techniques_initializer.py index bc1f909448..75fd1e2ce9 100644 --- a/tests/unit/setup/test_scenarios_initializer.py +++ b/tests/unit/setup/test_scenario_techniques_initializer.py @@ -15,14 +15,16 @@ from pyrit.registry import TargetRegistry from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.setup.initializers import ScenarioTechniqueInitializer -from pyrit.setup.initializers.components.scenarios import ( - CRESCENDO_HISTORY_LECTURE, - CRESCENDO_JOURNALIST_INTERVIEW, - CRESCENDO_MOVIE_DIRECTOR, - PERSONA_CRESCENDO_TECHNIQUE_NAMES, - build_persona_crescendo_specs, +from pyrit.setup.initializers.components.scenario_techniques import ( + build_scenario_technique_factories, ) +PERSONA_CRESCENDO_TECHNIQUE_NAMES: list[str] = [ + "crescendo_movie_director", + "crescendo_history_lecture", + "crescendo_journalist_interview", +] + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -40,7 +42,7 @@ def reset_registries(): @pytest.fixture def mock_adversarial_target(): - """A mock adversarial target registered as 'adversarial_chat' so build_scenario_techniques resolves cleanly.""" + """A mock adversarial target registered as 'adversarial_chat' so the initializer resolves cleanly.""" target = MagicMock(spec=PromptTarget) # capabilities check inside get_default_adversarial_target requires multi_turn support target.capabilities.includes.return_value = True @@ -72,55 +74,61 @@ def test_description_from_docstring(self): # --------------------------------------------------------------------------- -# Spec construction +# Factory construction # --------------------------------------------------------------------------- -class TestPersonaCrescendoSpecs: - """Tests for build_persona_crescendo_specs.""" +class TestPersonaCrescendoFactories: + """Tests for the persona-driven crescendo entries in the canonical factory list.""" + + @staticmethod + def _persona_factories(adversarial_target): + """Build the canonical catalog and pluck out the persona variants.""" + all_factories = build_scenario_technique_factories() + return [f for f in all_factories if f.name in PERSONA_CRESCENDO_TECHNIQUE_NAMES] - def test_returns_three_specs(self): - specs = build_persona_crescendo_specs() - assert len(specs) == 3 + def test_returns_three_factories(self, mock_adversarial_target): + factories = self._persona_factories(mock_adversarial_target) + assert len(factories) == 3 - def test_names_are_persona_variants(self): - specs = build_persona_crescendo_specs() - names = {s.name for s in specs} + def test_names_are_persona_variants(self, mock_adversarial_target): + factories = self._persona_factories(mock_adversarial_target) + names = {f.name for f in factories} assert names == { - CRESCENDO_MOVIE_DIRECTOR, - CRESCENDO_HISTORY_LECTURE, - CRESCENDO_JOURNALIST_INTERVIEW, + "crescendo_movie_director", + "crescendo_history_lecture", + "crescendo_journalist_interview", } - def test_all_use_prompt_sending_attack(self): - specs = build_persona_crescendo_specs() - for spec in specs: - assert spec.attack_class is PromptSendingAttack + def test_all_use_prompt_sending_attack(self, mock_adversarial_target): + factories = self._persona_factories(mock_adversarial_target) + for f in factories: + assert f.attack_class is PromptSendingAttack - def test_all_have_seed_technique_with_simulated_conversation(self): - specs = build_persona_crescendo_specs() - for spec in specs: - assert spec.seed_technique is not None - assert spec.seed_technique.has_simulated_conversation + def test_all_have_seed_technique_with_simulated_conversation(self, mock_adversarial_target): + factories = self._persona_factories(mock_adversarial_target) + for f in factories: + assert f.seed_technique is not None + assert f.seed_technique.has_simulated_conversation - def test_all_tagged_core_single_turn(self): - specs = build_persona_crescendo_specs() - for spec in specs: - assert "core" in spec.strategy_tags - assert "single_turn" in spec.strategy_tags + def test_all_tagged_core_single_turn(self, mock_adversarial_target): + factories = self._persona_factories(mock_adversarial_target) + for f in factories: + assert "core" in f.strategy_tags + assert "single_turn" in f.strategy_tags - def test_seed_technique_num_turns_matches_canonical_default(self): + def test_seed_technique_num_turns_matches_canonical_default(self, mock_adversarial_target): """Persona variants share the canonical num_turns=3 of crescendo_simulated.""" - specs = build_persona_crescendo_specs() - for spec in specs: - sim = spec.seed_technique.simulated_conversation_config + factories = self._persona_factories(mock_adversarial_target) + for f in factories: + sim = f.seed_technique.simulated_conversation_config assert sim is not None assert sim.num_turns == 3 - def test_seed_technique_yaml_path_resolves_to_existing_file(self): - specs = build_persona_crescendo_specs() - for spec in specs: - sim = spec.seed_technique.simulated_conversation_config + def test_seed_technique_yaml_path_resolves_to_existing_file(self, mock_adversarial_target): + factories = self._persona_factories(mock_adversarial_target) + for f in factories: + sim = f.seed_technique.simulated_conversation_config assert sim is not None assert sim.adversarial_chat_system_prompt_path.exists() @@ -178,31 +186,31 @@ async def test_registers_all_three_persona_techniques(self, mock_adversarial_tar registry = AttackTechniqueRegistry.get_registry_singleton() names = set(registry.get_names()) - assert CRESCENDO_MOVIE_DIRECTOR in names - assert CRESCENDO_HISTORY_LECTURE in names - assert CRESCENDO_JOURNALIST_INTERVIEW in names + assert "crescendo_movie_director" in names + assert "crescendo_history_lecture" in names + assert "crescendo_journalist_interview" in names @pytest.mark.asyncio async def test_also_registers_core_techniques(self, mock_adversarial_target): - """Initializer first calls register_scenario_techniques() to ensure core specs land.""" + """Initializer also registers the core factories alongside persona variants.""" init = ScenarioTechniqueInitializer() await init.initialize_async() registry = AttackTechniqueRegistry.get_registry_singleton() names = set(registry.get_names()) - # Core specs from PR #1665 era catalog - assert {"prompt_sending", "role_play", "many_shot", "tap", "crescendo_simulated"} <= names + # Core factories from build_scenario_technique_factories() + assert {"role_play", "many_shot", "tap", "crescendo_simulated"} <= names @pytest.mark.asyncio async def test_persona_factories_have_adversarial_config(self, mock_adversarial_target): - """Each persona factory has an adversarial config baked in (mirrors crescendo_simulated).""" + """Each persona factory marks itself as adversarial (lazy-resolves a chat in create()).""" init = ScenarioTechniqueInitializer() await init.initialize_async() registry = AttackTechniqueRegistry.get_registry_singleton() factories = registry.get_factories() for name in PERSONA_CRESCENDO_TECHNIQUE_NAMES: - assert factories[name].adversarial_chat is not None + assert factories[name].uses_adversarial is True @pytest.mark.asyncio async def test_persona_factories_carry_seed_technique(self, mock_adversarial_target): @@ -222,11 +230,11 @@ async def test_idempotent(self, mock_adversarial_target): registry = AttackTechniqueRegistry.get_registry_singleton() first_names = set(registry.get_names()) - first_factory = registry.get_factories()[CRESCENDO_MOVIE_DIRECTOR] + first_factory = registry.get_factories()["crescendo_movie_director"] await init.initialize_async() second_names = set(registry.get_names()) - second_factory = registry.get_factories()[CRESCENDO_MOVIE_DIRECTOR] + second_factory = registry.get_factories()["crescendo_movie_director"] assert first_names == second_names # Per-name idempotency: existing factory is preserved. @@ -234,9 +242,9 @@ async def test_idempotent(self, mock_adversarial_target): @pytest.mark.asyncio async def test_falls_back_to_default_target_when_registry_empty(self): - """With no 'adversarial_chat' in TargetRegistry, the fallback constructs an OpenAIChatTarget.""" - # Patch OpenAIChatTarget at the fallback construction site so the test - # does not depend on OPENAI_CHAT_MODEL or any other env var being set. + """With no 'adversarial_chat' in TargetRegistry, lazy resolution at create()-time + falls back to OpenAIChatTarget(temperature=1.2). + """ fallback_target = MagicMock(spec=PromptTarget) with patch( "pyrit.scenario.core.scenario_target_defaults.OpenAIChatTarget", @@ -245,14 +253,17 @@ async def test_falls_back_to_default_target_when_registry_empty(self): init = ScenarioTechniqueInitializer() await init.initialize_async() - # Fallback was taken: OpenAIChatTarget(temperature=1.2) was called - # at least once during get_default_adversarial_target resolution. - mock_openai.assert_any_call(temperature=1.2) + # Construction is now decoupled from adversarial resolution. + mock_openai.assert_not_called() + # Trigger the lazy fallback path explicitly. registry = AttackTechniqueRegistry.get_registry_singleton() factories = registry.get_factories() for name in PERSONA_CRESCENDO_TECHNIQUE_NAMES: - assert factories[name].adversarial_chat is fallback_target + config = factories[name]._resolve_default_adversarial_config() + assert config.target is fallback_target + + mock_openai.assert_any_call(temperature=1.2) # ---------------------------------------------------------------------------