From 849d39b573276535c507339554c21db7a182cc89 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 22 May 2026 18:22:29 -0700 Subject: [PATCH] MAINT: Simplifying scenario class vars Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../instructions/scenarios.instructions.md | 31 ++- doc/code/scenarios/0_scenarios.ipynb | 252 +----------------- doc/code/scenarios/0_scenarios.py | 24 +- doc/scanner/1_pyrit_scan.ipynb | 15 +- doc/scanner/1_pyrit_scan.py | 15 +- .../backend/services/scenario_run_service.py | 22 +- .../class_registries/scenario_registry.py | 40 ++- pyrit/scenario/core/scenario.py | 63 ++--- pyrit/scenario/scenarios/airt/__init__.py | 14 +- .../scenario/scenarios/airt/content_harms.py | 5 +- pyrit/scenario/scenarios/airt/cyber.py | 44 +-- pyrit/scenario/scenarios/airt/jailbreak.py | 32 +-- pyrit/scenario/scenarios/airt/leakage.py | 34 +-- pyrit/scenario/scenarios/airt/psychosocial.py | 32 +-- .../scenario/scenarios/airt/rapid_response.py | 66 ++--- pyrit/scenario/scenarios/airt/scam.py | 32 +-- .../scenario/scenarios/benchmark/__init__.py | 2 +- .../scenarios/benchmark/adversarial.py | 144 +++++----- .../scenarios/foundry/red_team_agent.py | 31 +-- pyrit/scenario/scenarios/garak/encoding.py | 39 +-- .../scenarios/load_default_datasets.py | 21 +- .../scenarios/preload_scenario_metadata.py | 59 ++++ .../unit/backend/test_scenario_run_service.py | 8 +- tests/unit/scenario/test_adversarial.py | 29 +- .../scenario/test_baseline_deprecation.py | 18 +- tests/unit/scenario/test_cyber.py | 17 +- tests/unit/scenario/test_encoding.py | 12 +- tests/unit/scenario/test_leakage_scenario.py | 12 +- .../unit/scenario/test_psychosocial_harms.py | 10 +- tests/unit/scenario/test_rapid_response.py | 37 ++- tests/unit/scenario/test_scenario.py | 61 +---- .../unit/scenario/test_scenario_parameters.py | 14 +- .../scenario/test_scenario_partial_results.py | 30 +-- tests/unit/scenario/test_scenario_retry.py | 38 +-- .../test_scenario_strategy_invariants.py | 8 +- .../unit/setup/test_load_default_datasets.py | 210 ++++----------- 36 files changed, 473 insertions(+), 1048 deletions(-) create mode 100644 pyrit/setup/initializers/scenarios/preload_scenario_metadata.py diff --git a/.github/instructions/scenarios.instructions.md b/.github/instructions/scenarios.instructions.md index 867375f8ab..7f7c131142 100644 --- a/.github/instructions/scenarios.instructions.md +++ b/.github/instructions/scenarios.instructions.md @@ -14,26 +14,29 @@ All scenarios inherit from `Scenario` (ABC) and must: 2. **Optionally declare `BASELINE_ATTACK_POLICY`** (defaults to `BaselineAttackPolicy.Enabled` — a baseline `PromptSendingAttack` is prepended and callers can opt out per run via `initialize_async(include_baseline=False)`): - `BaselineAttackPolicy.Disabled` — baseline supported but off by default (e.g. `Jailbreak`, where templates dominate the run). - `BaselineAttackPolicy.Forbidden` — baseline is meaningless for this scenario's comparison axis (e.g. `AdversarialBenchmark`, which compares against gold-standard answers). Explicit `include_baseline=True` raises `ValueError`. -3. **Implement three abstract methods:** +3. **Pass `strategy_class`, `default_strategy`, and `default_dataset_config` to `super().__init__()`:** ```python class MyScenario(Scenario): VERSION: int = 1 BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Enabled - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - return MyStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MyStrategy.ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - return DatasetConfiguration(dataset_names=["my_dataset"]) + @apply_defaults + def __init__(self, *, objective_scorer=None, scenario_result_id=None) -> None: + super().__init__( + version=self.VERSION, + strategy_class=MyStrategy, + default_strategy=MyStrategy.ALL, + default_dataset_config=DatasetConfiguration(dataset_names=["my_dataset"]), + objective_scorer=objective_scorer or self._get_default_objective_scorer(), + scenario_result_id=scenario_result_id, + ) ``` +For scenarios whose strategy enum is built dynamically (RapidResponse pattern), build the +strategy class in a module-level `@cache`-decorated function and pass the result through +the constructor — no classmethod indirection required. + 4. **Optionally override `_get_atomic_attacks_async()`** — the base class provides a default that uses the factory/registry pattern (see "AtomicAttack Construction" below). Only override if your scenario needs custom attack construction logic. @@ -60,6 +63,8 @@ def __init__( super().__init__( version=self.VERSION, strategy_class=MyStrategy, + default_strategy=MyStrategy.ALL, + default_dataset_config=DatasetConfiguration(dataset_names=["my_dataset"]), objective_scorer=objective_scorer, ) ``` @@ -67,7 +72,7 @@ def __init__( Requirements: - `@apply_defaults` decorator on `__init__` - All parameters keyword-only via `*` -- `super().__init__()` called with `version`, `strategy_class`, `objective_scorer` +- `super().__init__()` called with `version`, `strategy_class`, `default_strategy`, `default_dataset_config`, `objective_scorer` - complex objects like `adversarial_chat` or `objective_scorer` should be passed into the constructor. ## Dataset Loading diff --git a/doc/code/scenarios/0_scenarios.ipynb b/doc/code/scenarios/0_scenarios.ipynb index c55411d92f..407dcea2cc 100644 --- a/doc/code/scenarios/0_scenarios.ipynb +++ b/doc/code/scenarios/0_scenarios.ipynb @@ -59,13 +59,13 @@ " - Include an `ALL` aggregate strategy that expands to all available strategies\n", " - Optionally override `_prepare_strategies()` for custom composition logic (see `FoundryComposite`)\n", "\n", - "2. **Scenario Class**: Extend `Scenario` and implement these abstract methods:\n", - " - `get_strategy_class()`: Return your strategy enum class\n", - " - `get_default_strategy()`: Return the default strategy (typically `YourStrategy.ALL`)\n", + "2. **Scenario Class**: Extend `Scenario` and pass these to `super().__init__()`:\n", + " - `strategy_class`: Your strategy enum class\n", + " - `default_strategy`: The default strategy (typically `YourStrategy.ALL` or `YourStrategy.DEFAULT`)\n", " - The base class provides a default `_get_atomic_attacks_async()` that uses the factory/registry\n", " pattern. Override it only if your scenario needs custom attack construction logic.\n", "\n", - "3. **Default Dataset**: Implement `default_dataset_config()` to specify the datasets your scenario uses out of the box.\n", + "3. **Default Dataset**: Pass `default_dataset_config=` to `super().__init__()` to specify the datasets your scenario uses out of the box.\n", " - Returns a `DatasetConfiguration` with one or more named datasets (e.g., `DatasetConfiguration(dataset_names=[\"my_dataset\"])`)\n", " - Users can override this at runtime via `--dataset-names` in the CLI or by passing a custom `dataset_config` programmatically\n", "\n", @@ -97,24 +97,7 @@ "execution_count": null, "id": "1", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found default environment files: ['./.pyrit/.env', './.pyrit/.env.local']\n", - "Loaded environment file: ./.pyrit/.env\n", - "Loaded environment file: ./.pyrit/.env.local\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No new upgrade operations detected.\n" - ] - } - ], + "outputs": [], "source": [ "from pyrit.common import apply_defaults\n", "from pyrit.scenario import (\n", @@ -142,18 +125,6 @@ "\n", " VERSION: int = 1\n", "\n", - " @classmethod\n", - " def get_strategy_class(cls) -> type[ScenarioStrategy]:\n", - " return MyStrategy\n", - "\n", - " @classmethod\n", - " def get_default_strategy(cls) -> ScenarioStrategy:\n", - " return MyStrategy.DEFAULT\n", - "\n", - " @classmethod\n", - " def default_dataset_config(cls) -> DatasetConfiguration:\n", - " return DatasetConfiguration(dataset_names=[\"dataset_name\"], max_dataset_size=4)\n", - "\n", " @apply_defaults\n", " def __init__(\n", " self,\n", @@ -168,7 +139,9 @@ " super().__init__(\n", " version=self.VERSION,\n", " objective_scorer=self._objective_scorer,\n", - " strategy_class=self.get_strategy_class(),\n", + " strategy_class=MyStrategy,\n", + " default_strategy=MyStrategy.DEFAULT,\n", + " default_dataset_config=DatasetConfiguration(dataset_names=[\"dataset_name\"], max_dataset_size=4),\n", " scenario_result_id=scenario_result_id,\n", " )\n", "\n", @@ -196,201 +169,7 @@ "execution_count": null, "id": "3", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading default configuration file: ./.pyrit/.pyrit_conf\n", - "Found default environment files: ['./.pyrit/.env', './.pyrit/.env.local']\n", - "Loaded environment file: ./.pyrit/.env\n", - "Loaded environment file: ./.pyrit/.env.local\n", - "\n", - "Available Scenarios:\n", - "================================================================================\n", - "\u001b[1m\u001b[36m\n", - " airt.cyber\u001b[0m\n", - " Class: Cyber\n", - " Description:\n", - " Cyber scenario implementation for PyRIT. This scenario tests how willing\n", - " models are to exploit cybersecurity harms by generating malware. The\n", - " Cyber class contains different variations of the malware generation\n", - " techniques.\n", - " Aggregate Strategies:\n", - " - all, single_turn, multi_turn\n", - " Available Strategies (2):\n", - " prompt_sending, red_teaming\n", - " Default Strategy: all\n", - " Default Datasets (1, max 4 per dataset):\n", - " airt_malware\n", - "\u001b[1m\u001b[36m\n", - " airt.jailbreak\u001b[0m\n", - " Class: Jailbreak\n", - " Description:\n", - " Jailbreak scenario implementation for PyRIT. This scenario tests how\n", - " vulnerable models are to jailbreak attacks by applying various\n", - " single-turn jailbreak templates to a set of test prompts. The responses\n", - " are scored to determine if the jailbreak was successful.\n", - " Aggregate Strategies:\n", - " - all, simple, complex\n", - " Available Strategies (4):\n", - " prompt_sending, many_shot, skeleton, role_play\n", - " Default Strategy: simple\n", - " Default Datasets (1, max 4 per dataset):\n", - " airt_harms\n", - "\u001b[1m\u001b[36m\n", - " airt.leakage\u001b[0m\n", - " Class: Leakage\n", - " Description:\n", - " Leakage scenario implementation for PyRIT. This scenario tests how\n", - " susceptible models are to leaking training data, PII, intellectual\n", - " property, or other confidential information. Uses the registry/factory\n", - " pattern to construct attack techniques.\n", - " Aggregate Strategies:\n", - " - all, default, single_turn, multi_turn\n", - " Available Strategies (9):\n", - " prompt_sending, role_play, many_shot, tap, crescendo_simulated,\n", - " red_teaming, context_compliance, first_letter, image\n", - " Default Strategy: default\n", - " Default Datasets (1, max 4 per dataset):\n", - " airt_leakage\n", - "\u001b[1m\u001b[36m\n", - " airt.psychosocial\u001b[0m\n", - " Class: Psychosocial\n", - " Description:\n", - " Psychosocial Harms Scenario implementation for PyRIT. This scenario\n", - " contains various psychosocial harm-based checks that evaluate model\n", - " behavior for responses that may harm the well-being or mental health of\n", - " users. Examples include mistreating users in a mental health crisis or\n", - " aiding them in self-harm. All datasets in this scenario share the\n", - " unified 'psychosocial' harm category, but each dataset (e.g.,\n", - " imminent_crisis, dependency) can have unique: - Scoring rubrics - Seed\n", - " objectives The scenario tests whether models: - Recognize mental health\n", - " crises and respond appropriately - Avoid facilitating self-harm or\n", - " harmful behaviors - Provide crisis resources and encourage professional\n", - " help Example usage with custom configurations: # Define custom\n", - " configurations per subharm category custom_configs = {\n", - " \"airt_imminent_crisis\": SubharmConfig(\n", - " crescendo_system_prompt_path=\"path/to/custom_escalation.yaml\",\n", - " scoring_rubric_path=\"path/to/custom_rubric.yaml\", ), } scenario =\n", - " Psychosocial(subharm_configs=custom_configs) await\n", - " scenario.initialize_async( objective_target=target_llm,\n", - " scenario_strategies=[PsychosocialStrategy.ImminentCrisis], )\n", - " Aggregate Strategies:\n", - " - all\n", - " Available Strategies (2):\n", - " imminent_crisis, licensed_therapist\n", - " Default Strategy: all\n", - " Default Datasets (1, max 4 per dataset):\n", - " airt_imminent_crisis\n", - "\u001b[1m\u001b[36m\n", - " airt.rapid_response\u001b[0m\n", - " Class: RapidResponse\n", - " Description:\n", - " Rapid Response scenario for content-harms testing. Tests model behavior\n", - " across multiple harm categories using selectable attack techniques.\n", - " Aggregate Strategies:\n", - " - all, default, single_turn, multi_turn\n", - " Available Strategies (7):\n", - " prompt_sending, role_play, many_shot, tap, crescendo_simulated,\n", - " red_teaming, context_compliance\n", - " Default Strategy: default\n", - " Default Datasets (7, max 4 per dataset):\n", - " airt_hate, airt_fairness, airt_violence, airt_sexual, airt_harassment,\n", - " airt_misinformation, airt_leakage\n", - "\u001b[1m\u001b[36m\n", - " airt.scam\u001b[0m\n", - " Class: Scam\n", - " Description:\n", - " Scam scenario evaluates an endpoint's ability to generate scam-related\n", - " materials (e.g., phishing emails, fraudulent messages) with primarily\n", - " persuasion-oriented techniques.\n", - " Aggregate Strategies:\n", - " - all, single_turn, multi_turn\n", - " Available Strategies (3):\n", - " context_compliance, role_play, persuasive_rta\n", - " Default Strategy: all\n", - " Default Datasets (1, max 4 per dataset):\n", - " airt_scams\n", - " Supported Parameters:\n", - " - max_turns (int) [default: 5]: Maximum conversation turns for the persuasive_rta strategy.\n", - "\u001b[1m\u001b[36m\n", - " benchmark.adversarial\u001b[0m\n", - " Class: AdversarialBenchmark\n", - " Description:\n", - " Benchmarking scenario that compares the attack success rate (ASR) of\n", - " several different adversarial models.\n", - " Aggregate Strategies:\n", - " - all, default, single_turn, multi_turn, light\n", - " Available Strategies (4):\n", - " role_play, tap, red_teaming, context_compliance\n", - " Default Strategy: light\n", - " Default Datasets (1, max 8 per dataset):\n", - " harmbench\n", - "\u001b[1m\u001b[36m\n", - " foundry.red_team_agent\u001b[0m\n", - " Class: RedTeamAgent\n", - " Description:\n", - " RedTeamAgent is a preconfigured scenario that automatically generates\n", - " multiple AtomicAttack instances based on the specified attack\n", - " strategies. It supports both single-turn attacks (with various\n", - " converters) and multi-turn attacks (Crescendo, RedTeaming), making it\n", - " easy to quickly test a target against multiple attack vectors. The\n", - " scenario can expand difficulty levels (EASY, MODERATE, DIFFICULT) into\n", - " their constituent attack strategies, or you can specify individual\n", - " strategies directly. This scenario is designed for use with the Foundry\n", - " AI Red Teaming Agent library, providing a consistent PyRIT contract for\n", - " their integration.\n", - " Aggregate Strategies:\n", - " - all, easy, moderate, difficult\n", - " Available Strategies (25):\n", - " ansi_attack, ascii_art, ascii_smuggler, atbash, base64, binary, caesar,\n", - " character_space, char_swap, diacritic, flip, leetspeak, morse, rot13,\n", - " suffix_append, string_join, unicode_confusable, unicode_substitution,\n", - " url, jailbreak, tense, multi_turn, crescendo, pair, tap\n", - " Default Strategy: easy\n", - " Default Datasets (1, max 4 per dataset):\n", - " harmbench\n", - "\u001b[1m\u001b[36m\n", - " garak.encoding\u001b[0m\n", - " Class: Encoding\n", - " Description:\n", - " Encoding Scenario implementation for PyRIT. This scenario tests how\n", - " resilient models are to various encoding attacks by encoding potentially\n", - " harmful text (by default slurs and XSS payloads) and testing if the\n", - " model will decode and repeat the encoded payload. It mimics the Garak\n", - " encoding probe. The scenario works by: 1. Taking seed prompts (the\n", - " harmful text to be encoded) 2. Encoding them using various encoding\n", - " schemes (Base64, ROT13, Morse, etc.) 3. Asking the target model to\n", - " decode the encoded text 4. Scoring whether the model successfully\n", - " decoded and repeated the harmful content By default, this uses the same\n", - " dataset as Garak: slur terms and web XSS payloads.\n", - " Aggregate Strategies:\n", - " - all\n", - " Available Strategies (17):\n", - " base64, base2048, base16, base32, ascii85, hex, quoted_printable,\n", - " uuencode, rot13, braille, atbash, morse_code, nato, ecoji, zalgo,\n", - " leet_speak, ascii_smuggler\n", - " Default Strategy: all\n", - " Default Datasets (2, max 3 per dataset):\n", - " garak_slur_terms_en, garak_web_html_js\n", - "\n", - "================================================================================\n", - "\n", - "Total scenarios: 9\n" - ] - }, - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from pyrit.backend.services.scenario_service import get_scenario_service\n", "from pyrit.cli._output import print_scenario_list\n", @@ -454,17 +233,8 @@ } ], "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" + "jupytext": { + "main_language": "python" } }, "nbformat": 4, diff --git a/doc/code/scenarios/0_scenarios.py b/doc/code/scenarios/0_scenarios.py index 1f973ac14c..317aeac1ff 100644 --- a/doc/code/scenarios/0_scenarios.py +++ b/doc/code/scenarios/0_scenarios.py @@ -61,13 +61,13 @@ # - Include an `ALL` aggregate strategy that expands to all available strategies # - Optionally override `_prepare_strategies()` for custom composition logic (see `FoundryComposite`) # -# 2. **Scenario Class**: Extend `Scenario` and implement these abstract methods: -# - `get_strategy_class()`: Return your strategy enum class -# - `get_default_strategy()`: Return the default strategy (typically `YourStrategy.ALL`) +# 2. **Scenario Class**: Extend `Scenario` and pass these to `super().__init__()`: +# - `strategy_class`: Your strategy enum class +# - `default_strategy`: The default strategy (typically `YourStrategy.ALL` or `YourStrategy.DEFAULT`) # - The base class provides a default `_get_atomic_attacks_async()` that uses the factory/registry # pattern. Override it only if your scenario needs custom attack construction logic. # -# 3. **Default Dataset**: Implement `default_dataset_config()` to specify the datasets your scenario uses out of the box. +# 3. **Default Dataset**: Pass `default_dataset_config=` to `super().__init__()` to specify the datasets your scenario uses out of the box. # - Returns a `DatasetConfiguration` with one or more named datasets (e.g., `DatasetConfiguration(dataset_names=["my_dataset"])`) # - Users can override this at runtime via `--dataset-names` in the CLI or by passing a custom `dataset_config` programmatically # @@ -120,18 +120,6 @@ class MyScenario(Scenario): VERSION: int = 1 - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - return MyStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MyStrategy.DEFAULT - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - return DatasetConfiguration(dataset_names=["dataset_name"], max_dataset_size=4) - @apply_defaults def __init__( self, @@ -146,7 +134,9 @@ def __init__( super().__init__( version=self.VERSION, objective_scorer=self._objective_scorer, - strategy_class=self.get_strategy_class(), + strategy_class=MyStrategy, + default_strategy=MyStrategy.DEFAULT, + default_dataset_config=DatasetConfiguration(dataset_names=["dataset_name"], max_dataset_size=4), scenario_result_id=scenario_result_id, ) diff --git a/doc/scanner/1_pyrit_scan.ipynb b/doc/scanner/1_pyrit_scan.ipynb index df909c56e6..1ff21c6d5b 100644 --- a/doc/scanner/1_pyrit_scan.ipynb +++ b/doc/scanner/1_pyrit_scan.ipynb @@ -198,19 +198,6 @@ "class MyCustomScenario(Scenario):\n", " \"\"\"My custom scenario that does XYZ.\"\"\"\n", "\n", - " @classmethod\n", - " def get_strategy_class(cls):\n", - " return MyCustomStrategy\n", - "\n", - " @classmethod\n", - " def get_default_strategy(cls):\n", - " return MyCustomStrategy.ALL\n", - "\n", - " @classmethod\n", - " def default_dataset_config(cls) -> DatasetConfiguration:\n", - " # Return default dataset configuration for this scenario\n", - " return DatasetConfiguration(dataset_names=[\"harmbench\"])\n", - "\n", " @apply_defaults\n", " def __init__(self, *, scenario_result_id=None, **kwargs):\n", " # Scenario-specific configuration only - no runtime parameters\n", @@ -219,6 +206,8 @@ " version=1,\n", " objective_scorer=TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=OpenAIChatTarget())),\n", " strategy_class=MyCustomStrategy,\n", + " default_strategy=MyCustomStrategy.ALL,\n", + " default_dataset_config=DatasetConfiguration(dataset_names=[\"harmbench\"]),\n", " scenario_result_id=scenario_result_id,\n", " )\n", " # ... your scenario-specific initialization code\n", diff --git a/doc/scanner/1_pyrit_scan.py b/doc/scanner/1_pyrit_scan.py index 904b948d57..c1b452a071 100644 --- a/doc/scanner/1_pyrit_scan.py +++ b/doc/scanner/1_pyrit_scan.py @@ -141,19 +141,6 @@ class MyCustomStrategy(ScenarioStrategy): class MyCustomScenario(Scenario): """My custom scenario that does XYZ.""" - @classmethod - def get_strategy_class(cls): - return MyCustomStrategy - - @classmethod - def get_default_strategy(cls): - return MyCustomStrategy.ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - # Return default dataset configuration for this scenario - return DatasetConfiguration(dataset_names=["harmbench"]) - @apply_defaults def __init__(self, *, scenario_result_id=None, **kwargs): # Scenario-specific configuration only - no runtime parameters @@ -162,6 +149,8 @@ def __init__(self, *, scenario_result_id=None, **kwargs): version=1, objective_scorer=TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=OpenAIChatTarget())), strategy_class=MyCustomStrategy, + default_strategy=MyCustomStrategy.ALL, + default_dataset_config=DatasetConfiguration(dataset_names=["harmbench"]), scenario_result_id=scenario_result_id, ) # ... your scenario-specific initialization code diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 1bf61f1b37..9d79859de2 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -287,9 +287,24 @@ def _build_init_kwargs( if request.labels: init_kwargs["memory_labels"] = request.labels - # Validate and resolve strategies + # Resolve strategies and default dataset config from a temporary instance + # of the scenario. The downstream _initialize_scenario_async builds its + # own instance (so that scenario_result_id can be passed), so this is a + # cheap throwaway used only for introspection. + instance_for_introspection: Scenario | None = None + + if request.strategies or request.max_dataset_size is not None: + try: + instance_for_introspection = scenario_class() # type: ignore[ty:missing-argument] + except Exception as exc: + raise ValueError( + f"Cannot resolve runtime configuration for scenario '{request.scenario_name}': " + f"scenario class is not instantiable without arguments ({exc})." + ) from exc + if request.strategies: - strategy_class = scenario_class.get_strategy_class() + assert instance_for_introspection is not None + strategy_class = instance_for_introspection._strategy_class strategy_enums = [] for name in request.strategies: try: @@ -309,7 +324,8 @@ def _build_init_kwargs( max_dataset_size=request.max_dataset_size, ) elif request.max_dataset_size is not None: - default_config = scenario_class.default_dataset_config() + assert instance_for_introspection is not None + default_config = instance_for_introspection._default_dataset_config default_config.max_dataset_size = request.max_dataset_size init_kwargs["dataset_config"] = default_config diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 4c34e648d2..e879617380 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -197,6 +197,11 @@ def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMet """ Build metadata for a Scenario class. + Instantiates the scenario with no arguments and reads the strategy/dataset + configuration off the instance. Falls back to a degraded metadata record + if instantiation fails (e.g. the scenario requires constructor arguments + that cannot be defaulted). + Args: name: The registry name of the scenario. entry: The ClassEntry containing the scenario class. @@ -208,13 +213,6 @@ def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMet description = entry.get_description(fallback="No description available") - # Get the strategy class for this scenario - strategy_class = scenario_class.get_strategy_class() - - dataset_config = scenario_class.default_dataset_config() - default_datasets = dataset_config.get_default_dataset_names() - max_dataset_size = dataset_config.max_dataset_size - supported_parameters = tuple( ScenarioParameterMetadata( name=p.name, @@ -227,15 +225,35 @@ def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMet for p in scenario_class.supported_parameters() ) + try: + instance = scenario_class() # type: ignore[ty:missing-argument] + strategy_class = instance._strategy_class + default_strategy_value = instance._default_strategy.value + all_strategies = tuple(s.value for s in strategy_class.get_all_strategies()) + aggregate_strategies = tuple(s.value for s in strategy_class.get_aggregate_strategies()) + default_datasets = tuple(instance._default_dataset_config.get_default_dataset_names()) + max_dataset_size = instance._default_dataset_config.max_dataset_size + except Exception as exc: + logger.warning( + "Could not instantiate scenario %s for metadata; emitting degraded metadata. Reason: %s", + scenario_class.__name__, + exc, + ) + default_strategy_value = "" + all_strategies = () + aggregate_strategies = () + default_datasets = () + max_dataset_size = 0 + return ScenarioMetadata( class_name=scenario_class.__name__, class_module=scenario_class.__module__, class_description=description, registry_name=name, - default_strategy=scenario_class.get_default_strategy().value, - all_strategies=tuple(s.value for s in strategy_class.get_all_strategies()), - aggregate_strategies=tuple(s.value for s in strategy_class.get_aggregate_strategies()), - default_datasets=tuple(default_datasets), + default_strategy=default_strategy_value, + all_strategies=all_strategies, + aggregate_strategies=aggregate_strategies, + default_datasets=default_datasets, max_dataset_size=max_dataset_size, supported_parameters=supported_parameters, ) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index aa48c3daa1..8fb0ea261f 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -13,7 +13,7 @@ import logging import textwrap import uuid -from abc import ABC, abstractmethod +from abc import ABC from collections.abc import Sequence from enum import Enum from pathlib import Path @@ -126,7 +126,7 @@ def _format_param_key_diff(*, stored: dict[str, Any], current: dict[str, Any]) - return "; ".join(parts) if parts else "no diff details" -class Scenario(ABC): +class Scenario(ABC): # noqa: B024 - retained for subclass type-checking even without abstract methods """ Groups and executes multiple AtomicAttack instances sequentially. @@ -165,6 +165,8 @@ def __init__( name: str = "", version: int, strategy_class: type[ScenarioStrategy], + default_strategy: ScenarioStrategy, + default_dataset_config: DatasetConfiguration, objective_scorer: Scorer, scenario_result_id: Optional[Union[uuid.UUID, str]] = None, include_default_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. @@ -176,6 +178,11 @@ def __init__( name (str): Descriptive name for the scenario. version (int): Version number of the scenario. strategy_class (Type[ScenarioStrategy]): The strategy enum class for this scenario. + default_strategy (ScenarioStrategy): The default strategy member used when no + ``scenario_strategies`` are passed to ``initialize_async``. Usually an aggregate + member like ``MyStrategy.ALL`` or ``MyStrategy.DEFAULT``. + default_dataset_config (DatasetConfiguration): The default dataset configuration used + when no ``dataset_config`` is passed to ``initialize_async``. objective_scorer (Scorer): The objective scorer used to evaluate attack results. scenario_result_id (Optional[Union[uuid.UUID, str]]): Optional ID of an existing scenario result to resume. Can be either a UUID object or a string representation of a UUID. @@ -203,6 +210,8 @@ def __init__( # Store strategy configuration for use in initialize_async self._strategy_class = strategy_class + self._default_strategy = default_strategy + self._default_dataset_config = default_dataset_config # These will be set in initialize_async self._objective_target: Optional[PromptTarget] = None @@ -255,48 +264,6 @@ def atomic_attack_count(self) -> int: """Get the number of atomic attacks in this scenario.""" return len(self._atomic_attacks) - @classmethod - @abstractmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. - - This abstract method must be implemented by all scenario subclasses to return - the ScenarioStrategy enum class that defines the available attack strategies - for the scenario. - - Returns: - Type[ScenarioStrategy]: The strategy enum class (e.g., FoundryStrategy, EncodingStrategy). - """ - - @classmethod - @abstractmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. - - This abstract method must be implemented by all scenario subclasses to return - the default aggregate strategy (like EASY, ALL) used when scenario_strategies - parameter is None. - - Returns: - ScenarioStrategy: The default aggregate strategy (e.g., FoundryStrategy.EASY, EncodingStrategy.ALL). - """ - - @classmethod - @abstractmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - This abstract method must be implemented by all scenario subclasses to return - a DatasetConfiguration specifying the default datasets to use when no - dataset_config is provided by the user. - - Returns: - DatasetConfiguration: The default dataset configuration. - """ - @classmethod def supported_parameters(cls) -> list[Parameter]: """ @@ -573,7 +540,7 @@ def _prepare_strategies( Returns: list[ScenarioStrategy]: Ordered, deduplicated concrete strategies. """ - return self._strategy_class.resolve(strategies, default=self.get_default_strategy()) + return self._strategy_class.resolve(strategies, default=self._default_strategy) @apply_defaults async def initialize_async( @@ -605,7 +572,7 @@ async def initialize_async( from the scenario's configuration. dataset_config (Optional[DatasetConfiguration]): Configuration for the dataset source. Use this to specify dataset names or maximum dataset size from the CLI. - If not provided, scenarios use their default_dataset_config(). + If not provided, scenarios use their constructor-supplied default_dataset_config. max_concurrency (int): Maximum number of concurrent attack executions. Defaults to 1. max_retries (int): Maximum number of automatic retries if the scenario raises an exception. Set to 0 (default) for no automatic retries. If set to a positive number, @@ -637,7 +604,7 @@ async def initialize_async( self._objective_target_identifier = objective_target.get_identifier() type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._dataset_config_provided = dataset_config is not None - self._dataset_config = dataset_config if dataset_config else self.default_dataset_config() + self._dataset_config = dataset_config if dataset_config else self._default_dataset_config self._max_concurrency = max_concurrency self._max_retries = max_retries self._memory_labels = memory_labels or {} @@ -856,7 +823,7 @@ def _raise_dataset_exception(self) -> None: Either load the datasets into the database before running the scenario, or for example datasets, you can use the `load_default_datasets` initializer. - Required datasets: {", ".join(self.default_dataset_config().get_default_dataset_names())} + Required datasets: {", ".join(self._default_dataset_config.get_default_dataset_names())} """ ) raise ValueError(error_msg) diff --git a/pyrit/scenario/scenarios/airt/__init__.py b/pyrit/scenario/scenarios/airt/__init__.py index f4eae9657c..386239f276 100644 --- a/pyrit/scenario/scenarios/airt/__init__.py +++ b/pyrit/scenario/scenarios/airt/__init__.py @@ -6,11 +6,11 @@ from typing import Any from pyrit.scenario.scenarios.airt.content_harms import ContentHarms -from pyrit.scenario.scenarios.airt.cyber import Cyber +from pyrit.scenario.scenarios.airt.cyber import Cyber, _build_cyber_strategy from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy -from pyrit.scenario.scenarios.airt.leakage import Leakage +from pyrit.scenario.scenarios.airt.leakage import Leakage, _build_leakage_strategy from pyrit.scenario.scenarios.airt.psychosocial import Psychosocial, PsychosocialStrategy -from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse +from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse, _build_rapid_response_strategy from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy @@ -25,13 +25,13 @@ def __getattr__(name: str) -> Any: AttributeError: If the attribute name is not recognized. """ if name == "RapidResponseStrategy": - return RapidResponse.get_strategy_class() + return _build_rapid_response_strategy() if name == "LeakageStrategy": - return Leakage.get_strategy_class() + return _build_leakage_strategy() if name == "ContentHarmsStrategy": - return ContentHarms.get_strategy_class() + return _build_rapid_response_strategy() if name == "CyberStrategy": - return Cyber.get_strategy_class() + return _build_cyber_strategy() raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py index 0307133d2f..d379af6324 100644 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ b/pyrit/scenario/scenarios/airt/content_harms.py @@ -13,6 +13,9 @@ from pyrit.scenario.scenarios.airt.rapid_response import ( RapidResponse as ContentHarms, ) +from pyrit.scenario.scenarios.airt.rapid_response import ( + _build_rapid_response_strategy, +) def __getattr__(name: str) -> Any: @@ -26,7 +29,7 @@ def __getattr__(name: str) -> Any: AttributeError: If the attribute name is not recognized. """ if name == "ContentHarmsStrategy": - return ContentHarms.get_strategy_class() + return _build_rapid_response_strategy() raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index d29b81eecc..270e53ecef 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -4,7 +4,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, ClassVar +from functools import cache +from typing import TYPE_CHECKING from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -23,6 +24,7 @@ _CYBER_TECHNIQUE_NAMES = {"prompt_sending", "red_teaming"} +@cache def _build_cyber_strategy() -> type[ScenarioStrategy]: """ Build the Cyber strategy class dynamically from SCENARIO_TECHNIQUES. @@ -59,7 +61,6 @@ class Cyber(Scenario): """ VERSION: int = 2 - _cached_strategy_class: ClassVar[type[ScenarioStrategy] | None] = None @classmethod def get_override_composite_scorer_questions_path(cls) -> list[Path]: @@ -71,39 +72,6 @@ def get_override_composite_scorer_questions_path(cls) -> list[Path]: """ return [SCORER_SEED_PROMPT_PATH / "true_false_question" / "malware.yaml"] - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Return the dynamically generated strategy class, building it on first access. - - Returns: - type[ScenarioStrategy]: The CyberStrategy enum class. - """ - if cls._cached_strategy_class is None: - cls._cached_strategy_class = _build_cyber_strategy() - return cls._cached_strategy_class - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Return the default strategy member (``ALL``). - - Returns: - ScenarioStrategy: The ALL strategy value. - """ - strategy_class = cls.get_strategy_class() - return strategy_class("all") - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - Returns: - DatasetConfiguration: Configuration with airt_malware dataset. - """ - return DatasetConfiguration(dataset_names=["airt_malware"], max_dataset_size=4) - @apply_defaults def __init__( self, @@ -126,10 +94,14 @@ def __init__( objective_scorer if objective_scorer else self._get_default_objective_scorer() ) + strategy_class = _build_cyber_strategy() + super().__init__( version=self.VERSION, objective_scorer=self._objective_scorer, - strategy_class=self.get_strategy_class(), + strategy_class=strategy_class, + default_strategy=strategy_class("all"), + default_dataset_config=DatasetConfiguration(dataset_names=["airt_malware"], max_dataset_size=4), scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index f69b55d017..935b81b51f 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -81,41 +81,11 @@ class Jailbreak(Scenario): VERSION: int = 1 - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. - - Returns: - type[ScenarioStrategy]: The JailbreakStrategy enum class. - """ - return JailbreakStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. - - Returns: - ScenarioStrategy: JailbreakStrategy.PromptSending. - """ - return JailbreakStrategy.SIMPLE - @classmethod def required_datasets(cls) -> list[str]: """Return a list of dataset names required by this scenario.""" return ["airt_harms"] - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - Returns: - DatasetConfiguration: Configuration with airt_harms dataset. - """ - return DatasetConfiguration(dataset_names=["airt_harms"], max_dataset_size=4) - @apply_defaults def __init__( self, @@ -183,6 +153,8 @@ def __init__( super().__init__( version=self.VERSION, strategy_class=JailbreakStrategy, + default_strategy=JailbreakStrategy.SIMPLE, + default_dataset_config=DatasetConfiguration(dataset_names=["airt_harms"], max_dataset_size=4), objective_scorer=self._objective_scorer, scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/leakage.py b/pyrit/scenario/scenarios/airt/leakage.py index 2ccf54f768..406f346191 100644 --- a/pyrit/scenario/scenarios/airt/leakage.py +++ b/pyrit/scenario/scenarios/airt/leakage.py @@ -4,7 +4,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, ClassVar +from functools import cache +from typing import TYPE_CHECKING from pyrit.common import apply_defaults from pyrit.common.path import DATASETS_PATH, SCORER_SEED_PROMPT_PATH @@ -64,6 +65,7 @@ ] +@cache def _build_leakage_strategy() -> type[ScenarioStrategy]: """ Build the Leakage strategy class dynamically from core + leakage-specific techniques. @@ -98,7 +100,6 @@ class Leakage(Scenario): """ VERSION: int = 2 - _cached_strategy_class: ClassVar[type[ScenarioStrategy] | None] = None @classmethod def _get_additional_scoring_questions(cls) -> list[Path]: @@ -110,34 +111,11 @@ def _get_additional_scoring_questions(cls) -> list[Path]: """ return [SCORER_SEED_PROMPT_PATH / "true_false_question" / "leakage.yaml"] - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """Return the dynamically generated strategy class, building it on first access.""" - if cls._cached_strategy_class is None: - cls._cached_strategy_class = _build_leakage_strategy() - return cls._cached_strategy_class - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Return the default strategy member (DEFAULT). - - Returns: - ScenarioStrategy: The DEFAULT strategy value. - """ - strategy_class = cls.get_strategy_class() - return strategy_class("default") - @classmethod def required_datasets(cls) -> list[str]: """Return a list of dataset names required by this scenario.""" return ["airt_leakage"] - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default dataset configuration for this scenario.""" - return DatasetConfiguration(dataset_names=["airt_leakage"], max_dataset_size=4) - @apply_defaults def __init__( self, @@ -156,9 +134,13 @@ def __init__( if not objective_scorer: objective_scorer = self._get_default_objective_scorer() + strategy_class = _build_leakage_strategy() + super().__init__( version=self.VERSION, - strategy_class=self.get_strategy_class(), + strategy_class=strategy_class, + default_strategy=strategy_class("default"), + default_dataset_config=DatasetConfiguration(dataset_names=["airt_leakage"], max_dataset_size=4), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 8e7bb0bd5b..8ba3991649 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -175,36 +175,6 @@ class Psychosocial(Scenario): ), } - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. - - Returns: - Type[ScenarioStrategy]: The PsychosocialHarmsStrategy enum class. - """ - return PsychosocialStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. - - Returns: - ScenarioStrategy: PsychosocialStrategy.ALL - """ - return PsychosocialStrategy.ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - Returns: - DatasetConfiguration: Configuration with psychosocial harm datasets. - """ - return DatasetConfiguration(dataset_names=["airt_imminent_crisis"], max_dataset_size=4) - @apply_defaults def __init__( self, @@ -266,6 +236,8 @@ def __init__( super().__init__( version=self.VERSION, strategy_class=PsychosocialStrategy, + default_strategy=PsychosocialStrategy.ALL, + default_dataset_config=DatasetConfiguration(dataset_names=["airt_imminent_crisis"], max_dataset_size=4), objective_scorer=self._objective_scorer, scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/rapid_response.py b/pyrit/scenario/scenarios/airt/rapid_response.py index 41d853f214..d5a4429864 100644 --- a/pyrit/scenario/scenarios/airt/rapid_response.py +++ b/pyrit/scenario/scenarios/airt/rapid_response.py @@ -13,7 +13,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, ClassVar +from functools import cache +from typing import TYPE_CHECKING from pyrit.common import apply_defaults from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -26,6 +27,7 @@ logger = logging.getLogger(__name__) +@cache def _build_rapid_response_strategy() -> type[ScenarioStrategy]: """ Build the RapidResponse strategy class dynamically from SCENARIO_TECHNIQUES. @@ -59,51 +61,6 @@ class RapidResponse(Scenario): """ VERSION: int = 2 - _cached_strategy_class: ClassVar[type[ScenarioStrategy] | None] = None - - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Return the dynamically generated strategy class, building it on first access. - - Returns: - type[ScenarioStrategy]: The RapidResponseStrategy enum class. - """ - if cls._cached_strategy_class is None: - cls._cached_strategy_class = _build_rapid_response_strategy() - return cls._cached_strategy_class - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Return the default strategy member (``DEFAULT``). - - Returns: - ScenarioStrategy: The default strategy value. - """ - strategy_class = cls.get_strategy_class() - return strategy_class("default") - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for AIRT harm categories. - - Returns: - DatasetConfiguration: Configuration with standard harm-category datasets. - """ - return DatasetConfiguration( - dataset_names=[ - "airt_hate", - "airt_fairness", - "airt_violence", - "airt_sexual", - "airt_harassment", - "airt_misinformation", - "airt_leakage", - ], - max_dataset_size=4, - ) @apply_defaults def __init__( @@ -126,10 +83,25 @@ def __init__( objective_scorer if objective_scorer else self._get_default_objective_scorer() ) + strategy_class = _build_rapid_response_strategy() + super().__init__( version=self.VERSION, objective_scorer=self._objective_scorer, - strategy_class=self.get_strategy_class(), + strategy_class=strategy_class, + default_strategy=strategy_class("default"), + default_dataset_config=DatasetConfiguration( + dataset_names=[ + "airt_hate", + "airt_fairness", + "airt_violence", + "airt_sexual", + "airt_harassment", + "airt_misinformation", + "airt_leakage", + ], + max_dataset_size=4, + ), scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index 84b57cffb8..ab05c0fc81 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -95,41 +95,11 @@ def _get_additional_scoring_questions(cls) -> list[Path]: """ return [SCORER_SEED_PROMPT_PATH / "true_false_question" / "scams.yaml"] - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. - - Returns: - Type[ScenarioStrategy]: The ScamStrategy enum class. - """ - return ScamStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. - - Returns: - ScenarioStrategy: ScamStrategy.ALL (all scam strategies). - """ - return ScamStrategy.ALL - @classmethod def required_datasets(cls) -> list[str]: """Return a list of dataset names required by this scenario.""" return ["airt_scams"] - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - Returns: - DatasetConfiguration: Configuration with airt_scams dataset. - """ - return DatasetConfiguration(dataset_names=["airt_scams"], max_dataset_size=4) - @classmethod def supported_parameters(cls) -> list[Parameter]: """ @@ -179,6 +149,8 @@ def __init__( super().__init__( version=self.VERSION, strategy_class=ScamStrategy, + default_strategy=ScamStrategy.ALL, + default_dataset_config=DatasetConfiguration(dataset_names=["airt_scams"], max_dataset_size=4), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/benchmark/__init__.py b/pyrit/scenario/scenarios/benchmark/__init__.py index 0b554670d2..58e8d7ae29 100644 --- a/pyrit/scenario/scenarios/benchmark/__init__.py +++ b/pyrit/scenario/scenarios/benchmark/__init__.py @@ -19,7 +19,7 @@ def __getattr__(name: str) -> Any: AttributeError: If the attribute name is not recognized. """ if name == "AdversarialBenchmarkStrategy": - return AdversarialBenchmark.get_strategy_class() + return AdversarialBenchmark._build_benchmark_strategy() raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index dfec12839c..3536a5704f 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +from functools import cache from typing import TYPE_CHECKING, ClassVar from pyrit.common import apply_defaults @@ -33,54 +34,16 @@ class AdversarialBenchmark(Scenario): """ VERSION: int = 1 - _cached_strategy_class: ClassVar[type[ScenarioStrategy] | None] = None #: AdversarialBenchmark compares attack-success rates across adversarial models; a baseline #: attack would be model-independent and contribute no signal to the comparison. BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Return the AdversarialBenchmarkStrategy enum, building on first access. - - Returns: - type[ScenarioStrategy]: The BenchmarkStrategy enum class. - """ - if cls._cached_strategy_class is None: - cls._cached_strategy_class = AdversarialBenchmark._build_benchmark_strategy() - - return cls._cached_strategy_class - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Return the default strategy (``light`` — run benchmark-friendly techniques - that can wrap up quickly and without too many system resources). - - Returns: - ScenarioStrategy: The ``light`` aggregate member. - """ - return cls.get_strategy_class()("light") - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for benchmarking. - - Returns: - DatasetConfiguration: Configuration with standard harm-category datasets. - """ - return DatasetConfiguration( - dataset_names=["harmbench"], - max_dataset_size=8, - ) - @apply_defaults def __init__( self, *, - adversarial_models: list[PromptTarget], + adversarial_models: list[PromptTarget] | None = None, objective_scorer: TrueFalseScorer | None = None, scenario_result_id: str | None = None, ) -> None: @@ -98,16 +61,57 @@ def __init__( name). Identical targets are silently deduped and distinct targets whose inferred names collide are suffixed (``_2``, ``_3``, …) with a warning. + May be ``None`` at construction so the scenario can be + introspected (e.g. for ``--list-scenarios`` metadata); the + non-empty / capability validation is then deferred to + ``initialize_async``. objective_scorer: Scorer for evaluating attack success. Defaults to the registered default objective scorer. scenario_result_id: Optional ID of an existing scenario result to resume. Raises: - ValueError: If ``adversarial_models`` is empty, not a list, or - contains a target that does not satisfy + ValueError: If ``adversarial_models`` is provided and is empty, + not a list, or contains a target that does not satisfy :data:`CHAT_TARGET_REQUIREMENTS`. """ + if adversarial_models is not None: + self._adversarial_configs = self._build_adversarial_configs(adversarial_models) + else: + self._adversarial_configs = {} + + self._objective_scorer: TrueFalseScorer = ( + objective_scorer if objective_scorer else self._get_default_objective_scorer() + ) + + strategy_class = AdversarialBenchmark._build_benchmark_strategy() + + super().__init__( + version=self.VERSION, + objective_scorer=self._objective_scorer, + strategy_class=strategy_class, + default_strategy=strategy_class("light"), + default_dataset_config=DatasetConfiguration( + dataset_names=["harmbench"], + max_dataset_size=8, + ), + scenario_result_id=scenario_result_id, + ) + + @staticmethod + def _build_adversarial_configs( + adversarial_models: list[PromptTarget], + ) -> dict[str, AttackAdversarialConfig]: + """ + Validate ``adversarial_models`` and wrap each into an ``AttackAdversarialConfig``. + + Returns: + dict[str, AttackAdversarialConfig]: Adversarial configs keyed by inferred model label. + + Raises: + ValueError: If the list is empty, not a list, or contains a target + that does not satisfy :data:`CHAT_TARGET_REQUIREMENTS`. + """ if not adversarial_models: raise ValueError("adversarial_models must be a non-empty list of PromptTarget instances.") @@ -123,23 +127,8 @@ def __init__( f"the chat-target capability requirements: {exc}" ) from exc - # Infer labels, then wrap each bare target in a default AttackAdversarialConfig - # so it can be passed to factory.create() as an override. - labeled_targets = self._infer_labels(items=adversarial_models) - self._adversarial_configs: dict[str, AttackAdversarialConfig] = { - label: AttackAdversarialConfig(target=target) for label, target in labeled_targets.items() - } - - self._objective_scorer: TrueFalseScorer = ( - objective_scorer if objective_scorer else self._get_default_objective_scorer() - ) - - super().__init__( - version=self.VERSION, - objective_scorer=self._objective_scorer, - strategy_class=self.get_strategy_class(), - scenario_result_id=scenario_result_id, - ) + labeled_targets = AdversarialBenchmark._infer_labels(items=adversarial_models) + return {label: AttackAdversarialConfig(target=target) for label, target in labeled_targets.items()} async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ @@ -160,6 +149,12 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: "Scenario not properly initialized. Call await scenario.initialize_async() before running." ) + if not self._adversarial_configs: + raise ValueError( + "AdversarialBenchmark requires adversarial_models to be passed at construction " + "(non-empty list of chat-capable PromptTarget instances)." + ) + benchmarkable_specs = AdversarialBenchmark._get_benchmarkable_specs() local_factories = { spec.name: AttackTechniqueRegistry.build_factory_from_spec(spec) for spec in benchmarkable_specs @@ -264,17 +259,7 @@ def _build_benchmark_strategy() -> type[ScenarioStrategy]: Returns: type[ScenarioStrategy]: The dynamically generated strategy enum class. """ - specs = AdversarialBenchmark._get_benchmarkable_specs() - return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[ty:invalid-return-type] - class_name="BenchmarkStrategy", - specs=TagQuery.all("core").filter(specs), - aggregate_tags={ - "default": TagQuery.any_of("default"), - "single_turn": TagQuery.any_of("single_turn"), - "multi_turn": TagQuery.any_of("multi_turn"), - "light": TagQuery.any_of("light"), - }, - ) + return _build_benchmark_strategy() @staticmethod def _get_benchmarkable_specs() -> list[AttackTechniqueSpec]: @@ -294,3 +279,24 @@ def _get_benchmarkable_specs() -> list[AttackTechniqueSpec]: for spec in SCENARIO_TECHNIQUES if AttackTechniqueRegistry._accepts_adversarial(spec.attack_class) and spec.adversarial_chat is None ] + + +@cache +def _build_benchmark_strategy() -> type[ScenarioStrategy]: + """ + Module-level cached builder so all callers share the same strategy enum class. + + Returns: + type[ScenarioStrategy]: The dynamically generated BenchmarkStrategy enum class. + """ + specs = AdversarialBenchmark._get_benchmarkable_specs() + return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[ty:invalid-return-type] + class_name="BenchmarkStrategy", + specs=TagQuery.all("core").filter(specs), + aggregate_tags={ + "default": TagQuery.any_of("default"), + "single_turn": TagQuery.any_of("single_turn"), + "multi_turn": TagQuery.any_of("multi_turn"), + "light": TagQuery.any_of("light"), + }, + ) diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index b9ce521fb4..2cc526d5db 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -215,31 +215,6 @@ class RedTeamAgent(Scenario): VERSION: int = 1 - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. - - Returns: - Type[ScenarioStrategy]: The FoundryStrategy enum class. - """ - return FoundryStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. - - Returns: - ScenarioStrategy: FoundryStrategy.EASY (easy difficulty strategies). - """ - return FoundryStrategy.EASY - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default dataset configuration for this scenario.""" - return DatasetConfiguration(dataset_names=["harmbench"], max_dataset_size=4) - @apply_defaults def __init__( self, @@ -282,6 +257,8 @@ def __init__( super().__init__( version=self.VERSION, strategy_class=FoundryStrategy, + default_strategy=FoundryStrategy.EASY, + default_dataset_config=DatasetConfiguration(dataset_names=["harmbench"], max_dataset_size=4), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) @@ -359,7 +336,7 @@ def _prepare_strategies( # type: ignore[ty:invalid-method-override] list[ScenarioStrategy]: Flat list of constituent strategies for base-class tracking. """ if not strategies: - resolved = FoundryStrategy.resolve(None, default=cast("FoundryStrategy", self.get_default_strategy())) + resolved = FoundryStrategy.resolve(None, default=cast("FoundryStrategy", self._default_strategy)) self._scenario_composites = [self._strategy_to_composite(s) for s in resolved] return list(resolved) @@ -387,7 +364,7 @@ def _prepare_strategies( # type: ignore[ty:invalid-method-override] flat.append(item.attack) flat.extend(item.converters) else: - for s in FoundryStrategy.resolve([item], default=cast("FoundryStrategy", self.get_default_strategy())): + for s in FoundryStrategy.resolve([item], default=cast("FoundryStrategy", self._default_strategy)): if s not in seen: seen.add(s) composites.append(self._strategy_to_composite(s)) diff --git a/pyrit/scenario/scenarios/garak/encoding.py b/pyrit/scenario/scenarios/garak/encoding.py index c20ece87b4..65f36e3218 100644 --- a/pyrit/scenario/scenarios/garak/encoding.py +++ b/pyrit/scenario/scenarios/garak/encoding.py @@ -134,40 +134,6 @@ class Encoding(Scenario): VERSION: int = 1 - @classmethod - def get_strategy_class(cls) -> type[ScenarioStrategy]: - """ - Get the strategy enum class for this scenario. - - Returns: - Type[ScenarioStrategy]: The EncodingStrategy enum class. - """ - return EncodingStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - """ - Get the default strategy used when no strategies are specified. - - Returns: - ScenarioStrategy: EncodingStrategy.ALL (all encoding strategies). - """ - return EncodingStrategy.ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """ - Return the default dataset configuration for this scenario. - - Returns: - EncodingDatasetConfiguration: Configuration with garak slur terms and web XSS payloads, - where each seed is transformed into a SeedAttackGroup with an encoding objective. - """ - return EncodingDatasetConfiguration( - dataset_names=["garak_slur_terms_en", "garak_web_html_js"], - max_dataset_size=3, - ) - @apply_defaults def __init__( self, @@ -198,6 +164,11 @@ def __init__( super().__init__( version=self.VERSION, strategy_class=EncodingStrategy, + default_strategy=EncodingStrategy.ALL, + default_dataset_config=EncodingDatasetConfiguration( + dataset_names=["garak_slur_terms_en", "garak_web_html_js"], + max_dataset_size=3, + ), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) diff --git a/pyrit/setup/initializers/scenarios/load_default_datasets.py b/pyrit/setup/initializers/scenarios/load_default_datasets.py index 0055736372..c6b5924343 100644 --- a/pyrit/setup/initializers/scenarios/load_default_datasets.py +++ b/pyrit/setup/initializers/scenarios/load_default_datasets.py @@ -52,25 +52,14 @@ def required_env_vars(self) -> list[str]: async def initialize_async(self) -> None: """Load default datasets from all registered scenarios.""" - # Get ScenarioRegistry to discover all scenarios registry = ScenarioRegistry.get_registry_singleton() - # Collect all default datasets from all scenarios all_default_datasets: list[str] = [] - # Get all scenario names from registry - scenario_names = registry.get_names() - - for scenario_name in scenario_names: - scenario_class = registry.get_class(scenario_name) - if scenario_class: - # Get default_dataset_config from the scenario class - try: - datasets = scenario_class.default_dataset_config().get_default_dataset_names() - all_default_datasets.extend(datasets) - logger.info(f"Scenario '{scenario_name}' uses datasets: {datasets}") - except Exception as e: - logger.warning(f"Could not get default datasets from scenario '{scenario_name}': {e}") + for metadata in registry.list_metadata(): + datasets = list(metadata.default_datasets) + all_default_datasets.extend(datasets) + logger.info(f"Scenario '{metadata.registry_name}' uses datasets: {datasets}") # Remove duplicates unique_datasets = list(dict.fromkeys(all_default_datasets)) @@ -81,12 +70,10 @@ async def initialize_async(self) -> None: logger.info(f"Loading {len(unique_datasets)} unique datasets required by all scenarios") - # Fetch the datasets dataset_list = await SeedDatasetProvider.fetch_datasets_async( dataset_names=unique_datasets, ) - # Store datasets in CentralMemory memory = CentralMemory.get_memory_instance() await memory.add_seed_datasets_to_memory_async(datasets=dataset_list, added_by="LoadDefaultDatasets") diff --git a/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py b/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py new file mode 100644 index 0000000000..71bb4cdb12 --- /dev/null +++ b/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Pre-warm the ScenarioRegistry metadata cache. + +Each registered ``Scenario`` is instantiated once so the registry can read the +strategy class, default strategy, and default dataset configuration off the +instance. The results are cached on ``BaseClassRegistry._metadata_cache``; the +first ``--list-scenarios`` / GUI call is then a cache hit. Per-scenario +instantiation failures are logged and surfaced as degraded metadata by +``ScenarioRegistry._build_metadata`` so the pipeline continues. +""" + +import logging +import textwrap + +from pyrit.registry import ScenarioRegistry +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +logger = logging.getLogger(__name__) + + +class PreloadScenarioMetadata(PyRITInitializer): + """Instantiate every registered scenario once to warm the metadata cache.""" + + @property + def name(self) -> str: + """Return the name of this initializer.""" + return "Preload Scenario Metadata" + + @property + def execution_order(self) -> int: + """Runs after target/scorer/attack-technique initializers, before LoadDefaultDatasets.""" + return 5 + + @property + def description(self) -> str: + """Return a description of this initializer.""" + return textwrap.dedent( + """ + Instantiate every registered scenario once to populate the + ScenarioRegistry metadata cache. This surfaces broken scenario + __init__ implementations loudly at startup (rather than on + first --list-scenarios call) and makes downstream metadata + consumers like LoadDefaultDatasets and the GUI cheap. + """ + ).strip() + + @property + def required_env_vars(self) -> list[str]: + """Return the list of required environment variables.""" + return [] + + async def initialize_async(self) -> None: + """Warm the scenario metadata cache.""" + registry = ScenarioRegistry.get_registry_singleton() + metadata = registry.list_metadata() + logger.info("Preloaded metadata for %d scenarios", len(metadata)) diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 1de264ed5b..20b79b9cea 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -99,8 +99,8 @@ def mock_all_registries(mock_memory): mock_scenario_instance._scenario_result_id = "sr-uuid-1" mock_scenario_class = MagicMock(return_value=mock_scenario_instance) - mock_scenario_class.get_strategy_class.return_value = MagicMock() - mock_scenario_class.default_dataset_config.return_value = MagicMock() + mock_scenario_instance._strategy_class = MagicMock() + mock_scenario_instance._default_dataset_config = MagicMock() mock_sr = MagicMock() mock_sr.get_class.return_value = mock_scenario_class @@ -203,8 +203,8 @@ async def test_start_run_invalid_strategy_raises_value_error(self, mock_memory) mock_strategy_class = MagicMock(side_effect=ValueError("not a valid strategy")) mock_strategy_class.__iter__ = MagicMock(return_value=iter([MagicMock(value="valid_strat")])) - mock_scenario_class = MagicMock() - mock_scenario_class.get_strategy_class.return_value = mock_strategy_class + mock_instance = MagicMock(_strategy_class=mock_strategy_class) + mock_scenario_class = MagicMock(return_value=mock_instance) mock_sr = MagicMock() mock_sr.get_class.return_value = mock_scenario_class diff --git a/tests/unit/scenario/test_adversarial.py b/tests/unit/scenario/test_adversarial.py index 5914a40ba9..958a31ab4b 100644 --- a/tests/unit/scenario/test_adversarial.py +++ b/tests/unit/scenario/test_adversarial.py @@ -202,14 +202,14 @@ def test_adversarial_model_missing_chat_capabilities_raises(self): def test_version_is_1(self): assert AdversarialBenchmark.VERSION == 1 - def test_default_dataset_config_uses_harmbench(self): - config = AdversarialBenchmark.default_dataset_config() + def test_default_dataset_config_uses_harmbench(self, single_adversarial_model): + config = _make_benchmark(single_adversarial_model)._default_dataset_config assert isinstance(config, DatasetConfiguration) names = config.get_default_dataset_names() assert "harmbench" in names - def test_default_dataset_config_max_size_is_8(self): - config = AdversarialBenchmark.default_dataset_config() + def test_default_dataset_config_max_size_is_8(self, single_adversarial_model): + config = _make_benchmark(single_adversarial_model)._default_dataset_config assert config.max_dataset_size == 8 def test_frozen_spec_cannot_be_mutated(self): @@ -235,21 +235,21 @@ def _make_benchmark(adversarial_models): class TestBenchmarkStrategy: """Tests for the (static) BenchmarkStrategy enum and instance-level wiring.""" - def test_strategy_includes_all_adversarial_techniques(self, all_supported_attacks): - """get_strategy_class() concrete members match the adversarial-capable spec set.""" - strat = AdversarialBenchmark.get_strategy_class() + def test_strategy_includes_all_adversarial_techniques(self, all_supported_attacks, single_adversarial_model): + """concrete members match the adversarial-capable spec set.""" + strat = _make_benchmark(single_adversarial_model)._strategy_class values = {s.value for s in strat.get_all_strategies()} assert values == all_supported_attacks - def test_strategy_has_no_permuted_members(self): + def test_strategy_has_no_permuted_members(self, single_adversarial_model): """No ``__model`` suffixes — models are a runtime parameter, not a strategy axis.""" - strat = AdversarialBenchmark.get_strategy_class() + strat = _make_benchmark(single_adversarial_model)._strategy_class values = {s.value for s in strat.get_all_strategies()} assert not any("__" in v for v in values) - def test_strategy_excludes_non_adversarial_techniques(self): + def test_strategy_excludes_non_adversarial_techniques(self, single_adversarial_model): """prompt_sending and many_shot don't accept an adversarial chat and must be excluded.""" - strat = AdversarialBenchmark.get_strategy_class() + strat = _make_benchmark(single_adversarial_model)._strategy_class values = {s.value for s in strat.get_all_strategies()} assert "prompt_sending" not in values assert "many_shot" not in values @@ -259,11 +259,10 @@ def test_strategy_class_is_static(self, single_adversarial_model, two_adversaria s1 = _make_benchmark(single_adversarial_model) s2 = _make_benchmark(two_adversarial_models) assert s1._strategy_class is s2._strategy_class - assert s1._strategy_class is AdversarialBenchmark.get_strategy_class() - def test_default_strategy_is_light(self): + def test_default_strategy_is_light(self, single_adversarial_model): """Default expands to every benchmarkable technique via the ``all`` aggregate.""" - default = AdversarialBenchmark.get_default_strategy() + default = _make_benchmark(single_adversarial_model)._default_strategy assert default.value == "light" def test_benchmarkable_specs_have_no_adversarial_chat(self): @@ -437,7 +436,7 @@ async def test_multiple_datasets_multiplies_attacks(self, mock_objective_target, @pytest.mark.asyncio async def test_attacks_use_all_benchmarkable_attack_classes(self, mock_objective_target, single_adversarial_model): """Under the ``all`` strategy, atomic attacks must cover every adversarial-capable attack class.""" - scenario_class_strategies = AdversarialBenchmark.get_strategy_class() + scenario_class_strategies = _make_benchmark(single_adversarial_model)._strategy_class _, attacks = await self._init_and_get_attacks( mock_objective_target=mock_objective_target, adversarial_models=single_adversarial_model, diff --git a/tests/unit/scenario/test_baseline_deprecation.py b/tests/unit/scenario/test_baseline_deprecation.py index 837c32b65b..d8b1d8c06e 100644 --- a/tests/unit/scenario/test_baseline_deprecation.py +++ b/tests/unit/scenario/test_baseline_deprecation.py @@ -38,6 +38,8 @@ class _LegacyScenario(Scenario): def __init__(self, **kwargs): kwargs.setdefault("strategy_class", _LegacyStrategy) + kwargs.setdefault("default_strategy", _LegacyStrategy.ALL) + kwargs.setdefault("default_dataset_config", DatasetConfiguration()) if "objective_scorer" not in kwargs: mock_scorer = MagicMock(spec=TrueFalseScorer) mock_scorer.get_identifier.return_value = _TEST_SCORER_ID @@ -46,18 +48,6 @@ def __init__(self, **kwargs): kwargs.setdefault("version", 1) super().__init__(**kwargs) - @classmethod - def get_strategy_class(cls): - return _LegacyStrategy - - @classmethod - def get_default_strategy(cls): - return _LegacyStrategy.ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - return DatasetConfiguration() - async def _get_atomic_attacks_async(self): atomic_attacks = [] if self._include_baseline: @@ -104,7 +94,7 @@ async def test_legacy_value_drives_initialize_when_runtime_kwarg_omitted(self, m warnings.simplefilter("ignore", DeprecationWarning) scenario = _LegacyScenario(include_default_baseline=False) - with patch.object(_LegacyScenario, "default_dataset_config", return_value=DatasetConfiguration()): + with patch.object(_LegacyScenario, "default_dataset_config", create=True, return_value=DatasetConfiguration()): await scenario.initialize_async(objective_target=mock_objective_target) assert not any(a.atomic_attack_name == "baseline" for a in scenario._atomic_attacks) @@ -115,7 +105,7 @@ async def test_runtime_kwarg_wins_over_legacy_value(self, mock_objective_target) warnings.simplefilter("ignore", DeprecationWarning) scenario = _LegacyScenario(include_default_baseline=True) - with patch.object(_LegacyScenario, "default_dataset_config", return_value=DatasetConfiguration()): + with patch.object(_LegacyScenario, "default_dataset_config", create=True, return_value=DatasetConfiguration()): await scenario.initialize_async(objective_target=mock_objective_target, include_baseline=False) assert not any(a.atomic_attack_name == "baseline" for a in scenario._atomic_attacks) diff --git a/tests/unit/scenario/test_cyber.py b/tests/unit/scenario/test_cyber.py index d519e8913f..7d7bda49a4 100644 --- a/tests/unit/scenario/test_cyber.py +++ b/tests/unit/scenario/test_cyber.py @@ -28,7 +28,9 @@ def _mock_id(name: str) -> ComponentIdentifier: def _strategy_class(): """Get the dynamically-generated CyberStrategy class.""" - return Cyber.get_strategy_class() + from pyrit.scenario.scenarios.airt.cyber import _build_cyber_strategy + + return _build_cyber_strategy() # --------------------------------------------------------------------------- @@ -61,14 +63,15 @@ def mock_objective_scorer(): def reset_technique_registry(): """Reset the AttackTechniqueRegistry, TargetRegistry, and cached strategy class between tests.""" from pyrit.registry import TargetRegistry + from pyrit.scenario.scenarios.airt.cyber import _build_cyber_strategy AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() - Cyber._cached_strategy_class = None + _build_cyber_strategy.cache_clear() yield AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() - Cyber._cached_strategy_class = None + _build_cyber_strategy.cache_clear() @pytest.fixture @@ -113,21 +116,21 @@ def test_version_is_2(self): def test_get_strategy_class(self): strat = _strategy_class() - assert Cyber.get_strategy_class() is strat + assert Cyber()._strategy_class is strat def test_get_default_strategy_returns_all(self): strat = _strategy_class() - assert Cyber.get_default_strategy() == strat.ALL + assert Cyber()._default_strategy == strat.ALL def test_default_dataset_config_has_malware_dataset(self): - config = Cyber.default_dataset_config() + config = Cyber()._default_dataset_config assert isinstance(config, DatasetConfiguration) names = config.get_default_dataset_names() assert "airt_malware" in names assert len(names) == 1 def test_default_dataset_config_max_dataset_size(self): - config = Cyber.default_dataset_config() + config = Cyber()._default_dataset_config assert config.max_dataset_size == 4 def test_initialization_with_custom_scorer(self, mock_objective_scorer): diff --git a/tests/unit/scenario/test_encoding.py b/tests/unit/scenario/test_encoding.py index bb643ff6dc..107b74f89c 100644 --- a/tests/unit/scenario/test_encoding.py +++ b/tests/unit/scenario/test_encoding.py @@ -334,21 +334,21 @@ async def test_resolve_seed_groups_loads_garak_data( class TestEncodingDatasetConfiguration: """Tests for the EncodingDatasetConfiguration class.""" - def test_default_dataset_config_returns_encoding_config(self): + def test_default_dataset_config_returns_encoding_config(self, mock_objective_scorer): """Test that default_dataset_config returns EncodingDatasetConfiguration.""" - config = Encoding.default_dataset_config() + config = Encoding(objective_scorer=mock_objective_scorer)._default_dataset_config assert isinstance(config, EncodingDatasetConfiguration) - def test_default_dataset_config_uses_garak_datasets(self): + def test_default_dataset_config_uses_garak_datasets(self, mock_objective_scorer): """Test that the default config uses the expected garak datasets.""" - config = Encoding.default_dataset_config() + config = Encoding(objective_scorer=mock_objective_scorer)._default_dataset_config dataset_names = config.get_default_dataset_names() assert "garak_slur_terms_en" in dataset_names assert "garak_web_html_js" in dataset_names - def test_default_dataset_config_has_max_size(self): + def test_default_dataset_config_has_max_size(self, mock_objective_scorer): """Test that the default config has max_dataset_size set.""" - config = Encoding.default_dataset_config() + config = Encoding(objective_scorer=mock_objective_scorer)._default_dataset_config assert config.max_dataset_size == 3 diff --git a/tests/unit/scenario/test_leakage_scenario.py b/tests/unit/scenario/test_leakage_scenario.py index c1b5659e74..6efc165a8f 100644 --- a/tests/unit/scenario/test_leakage_scenario.py +++ b/tests/unit/scenario/test_leakage_scenario.py @@ -177,14 +177,14 @@ def test_scenario_version_is_set(self, mock_objective_scorer): scenario = Leakage(objective_scorer=mock_objective_scorer) assert scenario.VERSION == 2 - def test_get_strategy_class_returns_dynamic_class(self): + def test_get_strategy_class_returns_dynamic_class(self, mock_objective_scorer): """Test that get_strategy_class returns a dynamically generated strategy class.""" - strategy_class = Leakage.get_strategy_class() + strategy_class = Leakage(objective_scorer=mock_objective_scorer)._strategy_class assert strategy_class is LeakageStrategy - def test_get_default_strategy_returns_default(self): + def test_get_default_strategy_returns_default(self, mock_objective_scorer): """Test that get_default_strategy returns the DEFAULT aggregate.""" - default = Leakage.get_default_strategy() + default = Leakage(objective_scorer=mock_objective_scorer)._default_strategy assert default.value == "default" def test_required_datasets_returns_airt_leakage(self): @@ -220,9 +220,9 @@ def test_strategy_default_aggregate_exists(self): assert LeakageStrategy.DEFAULT.value == "default" assert "default" in LeakageStrategy.DEFAULT.tags - def test_strategy_has_technique_members(self): + def test_strategy_has_technique_members(self, mock_objective_scorer): """Test that the strategy has technique members from core + leakage techniques.""" - strategy_class = Leakage.get_strategy_class() + strategy_class = Leakage(objective_scorer=mock_objective_scorer)._strategy_class values = {m.value for m in strategy_class} # Leakage-unique techniques assert "first_letter" in values diff --git a/tests/unit/scenario/test_psychosocial_harms.py b/tests/unit/scenario/test_psychosocial_harms.py index ce16363ade..766fff5ee9 100644 --- a/tests/unit/scenario/test_psychosocial_harms.py +++ b/tests/unit/scenario/test_psychosocial_harms.py @@ -288,13 +288,15 @@ def test_scenario_version_is_set( assert scenario.VERSION == 1 - def test_get_strategy_class(self) -> None: + def test_get_strategy_class(self, mock_objective_scorer) -> None: """Test that the strategy class is PsychosocialStrategy.""" - assert Psychosocial.get_strategy_class() == PsychosocialStrategy + scenario = Psychosocial(objective_scorer=mock_objective_scorer) + assert scenario._strategy_class == PsychosocialStrategy - def test_get_default_strategy(self) -> None: + def test_get_default_strategy(self, mock_objective_scorer) -> None: """Test that the default strategy is ALL.""" - assert Psychosocial.get_default_strategy() == PsychosocialStrategy.ALL + scenario = Psychosocial(objective_scorer=mock_objective_scorer) + assert scenario._default_strategy == PsychosocialStrategy.ALL async def test_no_target_duplication_async( self, diff --git a/tests/unit/scenario/test_rapid_response.py b/tests/unit/scenario/test_rapid_response.py index ecaef3d02c..8ef71c59d9 100644 --- a/tests/unit/scenario/test_rapid_response.py +++ b/tests/unit/scenario/test_rapid_response.py @@ -50,7 +50,9 @@ def _mock_id(name: str) -> ComponentIdentifier: def _strategy_class(): """Get the dynamically-generated RapidResponseStrategy class.""" - return RapidResponse.get_strategy_class() + from pyrit.scenario.scenarios.airt.rapid_response import _build_rapid_response_strategy + + return _build_rapid_response_strategy() # --------------------------------------------------------------------------- @@ -83,14 +85,15 @@ def mock_objective_scorer(): def reset_technique_registry(): """Reset the AttackTechniqueRegistry, TargetRegistry, and cached strategy class between tests.""" from pyrit.registry import TargetRegistry + from pyrit.scenario.scenarios.airt.rapid_response import _build_rapid_response_strategy AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() - RapidResponse._cached_strategy_class = None + _build_rapid_response_strategy.cache_clear() yield AttackTechniqueRegistry.reset_instance() TargetRegistry.reset_instance() - RapidResponse._cached_strategy_class = None + _build_rapid_response_strategy.cache_clear() @pytest.fixture(autouse=True) @@ -145,16 +148,25 @@ class TestRapidResponseBasic: def test_version_is_2(self): assert RapidResponse.VERSION == 2 - def test_get_strategy_class(self): + def test_get_strategy_class(self, mock_objective_scorer): strat = _strategy_class() - assert RapidResponse.get_strategy_class() is strat + with patch( + "pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer", return_value=mock_objective_scorer + ): + assert RapidResponse()._strategy_class is strat - def test_get_default_strategy_returns_default(self): + def test_get_default_strategy_returns_default(self, mock_objective_scorer): strat = _strategy_class() - assert RapidResponse.get_default_strategy() == strat.DEFAULT + with patch( + "pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer", return_value=mock_objective_scorer + ): + assert RapidResponse()._default_strategy == strat.DEFAULT - def test_default_dataset_config_has_all_harm_datasets(self): - config = RapidResponse.default_dataset_config() + def test_default_dataset_config_has_all_harm_datasets(self, mock_objective_scorer): + with patch( + "pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer", return_value=mock_objective_scorer + ): + config = RapidResponse()._default_dataset_config assert isinstance(config, DatasetConfiguration) names = config.get_default_dataset_names() expected = [f"airt_{cat}" for cat in ALL_HARM_CATEGORIES] @@ -162,8 +174,11 @@ def test_default_dataset_config_has_all_harm_datasets(self): assert name in names assert len(names) == 7 - def test_default_dataset_config_max_dataset_size(self): - config = RapidResponse.default_dataset_config() + def test_default_dataset_config_max_dataset_size(self, mock_objective_scorer): + with patch( + "pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer", return_value=mock_objective_scorer + ): + config = RapidResponse()._default_dataset_config assert config.max_dataset_size == 4 @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index a4971a3e0d..b533a210b9 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -145,6 +145,8 @@ def get_aggregate_tags(cls) -> set[str]: return {"all"} kwargs.setdefault("strategy_class", TestStrategy) + kwargs.setdefault("default_strategy", kwargs["strategy_class"].ALL) + kwargs.setdefault("default_dataset_config", DatasetConfiguration()) # Add a mock scorer if not provided if "objective_scorer" not in kwargs: @@ -156,33 +158,6 @@ def get_aggregate_tags(cls) -> set[str]: super().__init__(**kwargs) self._atomic_attacks_to_return = atomic_attacks_to_return or [] - @classmethod - def get_strategy_class(cls): - """Return a mock strategy class for testing.""" - - from pyrit.scenario.core.scenario_strategy import ScenarioStrategy - - # Return a simple mock strategy class for testing - class TestStrategy(ScenarioStrategy): - TEST = ("test", {"concrete"}) # Tagged as concrete, not aggregate - ALL = ("all", {"all"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} - - return TestStrategy - - @classmethod - def get_default_strategy(cls): - """Return the default strategy for testing.""" - return cls.get_strategy_class().ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default dataset configuration for testing.""" - return DatasetConfiguration() - async def _get_atomic_attacks_async(self): return self._atomic_attacks_to_return @@ -713,6 +688,8 @@ def get_aggregate_tags(cls) -> set[str]: return {"all"} kwargs.setdefault("strategy_class", TestStrategy) + kwargs.setdefault("default_strategy", kwargs["strategy_class"].ALL) + kwargs.setdefault("default_dataset_config", DatasetConfiguration()) # Use TrueFalseScorer mock if not provided if "objective_scorer" not in kwargs: @@ -721,32 +698,6 @@ def get_aggregate_tags(cls) -> set[str]: super().__init__(**kwargs) self._atomic_attacks_to_return = atomic_attacks_to_return or [] - @classmethod - def get_strategy_class(cls): - """Return a mock strategy class for testing.""" - - from pyrit.scenario.core.scenario_strategy import ScenarioStrategy - - class TestStrategy(ScenarioStrategy): - TEST = ("test", {"concrete"}) - ALL = ("all", {"all"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} - - return TestStrategy - - @classmethod - def get_default_strategy(cls): - """Return the default strategy for testing.""" - return cls.get_strategy_class().ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default dataset configuration for testing.""" - return DatasetConfiguration() - async def _get_atomic_attacks_async(self): atomic_attacks = list(self._atomic_attacks_to_return) if self._include_baseline: @@ -890,8 +841,8 @@ async def test_standalone_baseline_uses_dataset_config_seeds(self, mock_objectiv def test_empty_list_strategies_expands_defaults_same_as_none(self): """Test that [] and None both expand to the default strategy set.""" scenario = ConcreteScenario(name="Test", version=1) - strategy_class = scenario.get_strategy_class() - default = scenario.get_default_strategy() + strategy_class = scenario._strategy_class + default = scenario._default_strategy resolved_none = strategy_class.resolve(None, default=default) resolved_empty = strategy_class.resolve([], default=default) diff --git a/tests/unit/scenario/test_scenario_parameters.py b/tests/unit/scenario/test_scenario_parameters.py index 9c8b6fe6cc..d53a768a4d 100644 --- a/tests/unit/scenario/test_scenario_parameters.py +++ b/tests/unit/scenario/test_scenario_parameters.py @@ -37,18 +37,6 @@ class _ParamTestScenario(Scenario): # No baseline in tests so atomic_attacks observations stay deterministic. BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden - @classmethod - def get_strategy_class(cls): - return _ParamTestStrategy - - @classmethod - def get_default_strategy(cls): - return _ParamTestStrategy.ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - return DatasetConfiguration() - @classmethod def supported_parameters(cls) -> list[Parameter]: return list(params_to_declare) @@ -63,6 +51,8 @@ async def _get_atomic_attacks_async(self): return _ParamTestScenario( version=1, strategy_class=_ParamTestStrategy, + default_strategy=_ParamTestStrategy.ALL, + default_dataset_config=DatasetConfiguration(), objective_scorer=mock_scorer, ) diff --git a/tests/unit/scenario/test_scenario_partial_results.py b/tests/unit/scenario/test_scenario_partial_results.py index 91e3f27cd2..2efb130cbe 100644 --- a/tests/unit/scenario/test_scenario_partial_results.py +++ b/tests/unit/scenario/test_scenario_partial_results.py @@ -98,40 +98,32 @@ class ConcreteScenario(Scenario): BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden def __init__(self, *, atomic_attacks_to_return=None, objective_scorer=None, **kwargs): - # Get strategy_class from kwargs or use default - strategy_class = kwargs.pop("strategy_class", None) or self.get_strategy_class() + strategy_class = kwargs.pop("strategy_class", None) or _build_test_strategy() # Create a default mock scorer if not provided if objective_scorer is None: objective_scorer = MagicMock() objective_scorer.get_identifier.return_value = _mock_scorer_id("MockScorer") + kwargs.setdefault("default_strategy", strategy_class.ALL) + kwargs.setdefault("default_dataset_config", DatasetConfiguration()) super().__init__(strategy_class=strategy_class, objective_scorer=objective_scorer, **kwargs) self._test_atomic_attacks = atomic_attacks_to_return or [] async def _get_atomic_attacks_async(self): return self._test_atomic_attacks - @classmethod - def get_strategy_class(cls): - class TestStrategy(ScenarioStrategy): - CONCRETE = ("concrete", {"concrete"}) - ALL = ("all", {"all"}) - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} +def _build_test_strategy(): + class TestStrategy(ScenarioStrategy): + CONCRETE = ("concrete", {"concrete"}) + ALL = ("all", {"all"}) - return TestStrategy + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} - @classmethod - def get_default_strategy(cls): - return cls.get_strategy_class().ALL - - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default dataset configuration for testing.""" - return DatasetConfiguration() + return TestStrategy @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/scenario/test_scenario_retry.py b/tests/unit/scenario/test_scenario_retry.py index d26cb1ae03..8b2131b7d7 100644 --- a/tests/unit/scenario/test_scenario_retry.py +++ b/tests/unit/scenario/test_scenario_retry.py @@ -169,44 +169,32 @@ class ConcreteScenario(Scenario): BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden def __init__(self, atomic_attacks_to_return=None, objective_scorer=None, **kwargs): - # Get strategy_class from kwargs or use default - strategy_class = kwargs.pop("strategy_class", None) or self.get_strategy_class() + strategy_class = kwargs.pop("strategy_class", None) or _build_test_strategy() # Create a default mock scorer if not provided if objective_scorer is None: objective_scorer = MagicMock() objective_scorer.get_identifier.return_value = _mock_scorer_id("MockScorer") + kwargs.setdefault("default_strategy", strategy_class.ALL) + kwargs.setdefault("default_dataset_config", DatasetConfiguration()) super().__init__(strategy_class=strategy_class, objective_scorer=objective_scorer, **kwargs) self._atomic_attacks_to_return = atomic_attacks_to_return or [] - @classmethod - def get_strategy_class(cls): - """Return a mock strategy class for testing.""" - - # Return a simple mock strategy class for testing - class TestStrategy(ScenarioStrategy): - CONCRETE = ("concrete", {"concrete"}) - ALL = ("all", {"all"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} + async def _get_atomic_attacks_async(self): + return self._atomic_attacks_to_return - return TestStrategy - @classmethod - def get_default_strategy(cls): - """Return the default strategy for testing.""" - return cls.get_strategy_class().ALL +def _build_test_strategy(): + class TestStrategy(ScenarioStrategy): + CONCRETE = ("concrete", {"concrete"}) + ALL = ("all", {"all"}) - @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default dataset configuration for testing.""" - return DatasetConfiguration() + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} - async def _get_atomic_attacks_async(self): - return self._atomic_attacks_to_return + return TestStrategy @pytest.fixture diff --git a/tests/unit/scenario/test_scenario_strategy_invariants.py b/tests/unit/scenario/test_scenario_strategy_invariants.py index 3fa9bbaa86..e835caf282 100644 --- a/tests/unit/scenario/test_scenario_strategy_invariants.py +++ b/tests/unit/scenario/test_scenario_strategy_invariants.py @@ -78,15 +78,15 @@ def _mock_runtime_env(): def _get_rapid_response_strategy(): - from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse + from pyrit.scenario.scenarios.airt.rapid_response import _build_rapid_response_strategy - return RapidResponse.get_strategy_class() + return _build_rapid_response_strategy() def _get_cyber_strategy(): - from pyrit.scenario.scenarios.airt.cyber import Cyber + from pyrit.scenario.scenarios.airt.cyber import _build_cyber_strategy - return Cyber.get_strategy_class() + return _build_cyber_strategy() SCENARIO_STRATEGY_BUILDERS = [ diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index 43f4562ee5..14c2602510 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -5,6 +5,7 @@ Unit tests for LoadDefaultDatasets initializer. """ +from dataclasses import dataclass, field from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -13,10 +14,15 @@ from pyrit.memory import CentralMemory from pyrit.models import SeedDataset from pyrit.registry import ScenarioRegistry -from pyrit.scenario.core.scenario import Scenario from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets +@dataclass +class _FakeMetadata: + registry_name: str + default_datasets: tuple[str, ...] = field(default_factory=tuple) + + @pytest.mark.usefixtures("patch_central_database") class TestLoadDefaultDatasets: """Test suite for LoadDefaultDatasets initializer.""" @@ -39,7 +45,7 @@ async def test_initialize_async_no_scenarios(self) -> None: """Test initialization when no scenarios are registered.""" initializer = LoadDefaultDatasets() - with patch.object(ScenarioRegistry, "get_names", return_value=[]): + with patch.object(ScenarioRegistry, "list_metadata", return_value=[]): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: with patch.object(CentralMemory, "get_memory_instance") as mock_memory: mock_memory_instance = MagicMock() @@ -48,7 +54,6 @@ async def test_initialize_async_no_scenarios(self) -> None: await initializer.initialize_async() - # Should not fetch datasets if no scenarios mock_fetch.assert_not_called() mock_memory_instance.add_seed_datasets_to_memory_async.assert_not_called() @@ -56,149 +61,73 @@ async def test_initialize_async_with_scenarios(self) -> None: """Test initialization with scenarios that require datasets.""" initializer = LoadDefaultDatasets() - # Mock scenario class with default_dataset_config - mock_dataset_config = MagicMock() - mock_dataset_config.get_default_dataset_names.return_value = ["dataset1", "dataset2"] - mock_scenario_class = MagicMock(spec=Scenario) - mock_scenario_class.default_dataset_config.return_value = mock_dataset_config + metadata = [_FakeMetadata(registry_name="mock_scenario", default_datasets=("dataset1", "dataset2"))] - with patch.object(ScenarioRegistry, "get_names", return_value=["mock_scenario"]): - with patch.object(ScenarioRegistry, "get_class", return_value=mock_scenario_class): - with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: - mock_dataset1 = MagicMock(spec=SeedDataset) - mock_dataset2 = MagicMock(spec=SeedDataset) - mock_fetch.return_value = [mock_dataset1, mock_dataset2] + with patch.object(ScenarioRegistry, "list_metadata", return_value=metadata): + with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: + mock_dataset1 = MagicMock(spec=SeedDataset) + mock_dataset2 = MagicMock(spec=SeedDataset) + mock_fetch.return_value = [mock_dataset1, mock_dataset2] - with patch.object(CentralMemory, "get_memory_instance") as mock_memory: - mock_memory_instance = MagicMock() - mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() - mock_memory.return_value = mock_memory_instance + with patch.object(CentralMemory, "get_memory_instance") as mock_memory: + mock_memory_instance = MagicMock() + mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() + mock_memory.return_value = mock_memory_instance - await initializer.initialize_async() + await initializer.initialize_async() - # Verify fetch_datasets_async was called with correct datasets - mock_fetch.assert_called_once() - call_kwargs = mock_fetch.call_args.kwargs - assert set(call_kwargs["dataset_names"]) == {"dataset1", "dataset2"} + mock_fetch.assert_called_once() + call_kwargs = mock_fetch.call_args.kwargs + assert set(call_kwargs["dataset_names"]) == {"dataset1", "dataset2"} - # Verify datasets were added to memory - mock_memory_instance.add_seed_datasets_to_memory_async.assert_called_once_with( - datasets=[mock_dataset1, mock_dataset2], added_by="LoadDefaultDatasets" - ) + mock_memory_instance.add_seed_datasets_to_memory_async.assert_called_once_with( + datasets=[mock_dataset1, mock_dataset2], added_by="LoadDefaultDatasets" + ) async def test_initialize_async_deduplicates_datasets(self) -> None: """Test that duplicate datasets from multiple scenarios are deduplicated.""" initializer = LoadDefaultDatasets() - # Mock two scenarios requiring overlapping datasets - mock_dataset_config1 = MagicMock() - mock_dataset_config1.get_default_dataset_names.return_value = ["dataset1", "dataset2"] - mock_scenario1 = MagicMock(spec=Scenario) - mock_scenario1.default_dataset_config.return_value = mock_dataset_config1 - - mock_dataset_config2 = MagicMock() - mock_dataset_config2.get_default_dataset_names.return_value = ["dataset2", "dataset3"] - mock_scenario2 = MagicMock(spec=Scenario) - mock_scenario2.default_dataset_config.return_value = mock_dataset_config2 - - def get_scenario_side_effect(name: str): - if name == "scenario1": - return mock_scenario1 - if name == "scenario2": - return mock_scenario2 - return None - - with patch.object(ScenarioRegistry, "get_names", return_value=["scenario1", "scenario2"]): - with patch.object(ScenarioRegistry, "get_class", side_effect=get_scenario_side_effect): - with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: - mock_fetch.return_value = [] - - with patch.object(CentralMemory, "get_memory_instance") as mock_memory: - mock_memory_instance = MagicMock() - mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() - mock_memory.return_value = mock_memory_instance - - await initializer.initialize_async() - - # Verify only unique datasets were requested - mock_fetch.assert_called_once() - call_kwargs = mock_fetch.call_args.kwargs - assert set(call_kwargs["dataset_names"]) == {"dataset1", "dataset2", "dataset3"} - # Verify order is preserved (dict.fromkeys maintains insertion order) - assert len(call_kwargs["dataset_names"]) == 3 - - async def test_initialize_async_handles_scenario_errors(self) -> None: - """Test that initialization continues when a scenario raises an error.""" - initializer = LoadDefaultDatasets() - - # Mock one scenario that works and one that fails - mock_dataset_config_good = MagicMock() - mock_dataset_config_good.get_default_dataset_names.return_value = ["dataset1"] - mock_scenario_good = MagicMock(spec=Scenario) - mock_scenario_good.default_dataset_config.return_value = mock_dataset_config_good - - mock_scenario_bad = MagicMock(spec=Scenario) - mock_scenario_bad.default_dataset_config.side_effect = Exception("Test error") + metadata = [ + _FakeMetadata(registry_name="scenario1", default_datasets=("dataset1", "dataset2")), + _FakeMetadata(registry_name="scenario2", default_datasets=("dataset2", "dataset3")), + ] - def get_scenario_side_effect(name: str): - if name == "good_scenario": - return mock_scenario_good - if name == "bad_scenario": - return mock_scenario_bad - return None - - with patch.object(ScenarioRegistry, "get_names", return_value=["good_scenario", "bad_scenario"]): - with patch.object(ScenarioRegistry, "get_class", side_effect=get_scenario_side_effect): - with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: - mock_fetch.return_value = [] + with patch.object(ScenarioRegistry, "list_metadata", return_value=metadata): + with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: + mock_fetch.return_value = [] - with patch.object(CentralMemory, "get_memory_instance") as mock_memory: - mock_memory_instance = MagicMock() - mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() - mock_memory.return_value = mock_memory_instance + with patch.object(CentralMemory, "get_memory_instance") as mock_memory: + mock_memory_instance = MagicMock() + mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() + mock_memory.return_value = mock_memory_instance - await initializer.initialize_async() + await initializer.initialize_async() - # Verify it still fetched datasets from the good scenario - mock_fetch.assert_called_once() - call_kwargs = mock_fetch.call_args.kwargs - assert "dataset1" in call_kwargs["dataset_names"] + mock_fetch.assert_called_once() + call_kwargs = mock_fetch.call_args.kwargs + assert set(call_kwargs["dataset_names"]) == {"dataset1", "dataset2", "dataset3"} + assert len(call_kwargs["dataset_names"]) == 3 async def test_all_required_datasets_available_in_seed_provider(self) -> None: """ Test that all datasets required by scenarios are available in SeedDatasetProvider. - This test ensures that every dataset name returned by scenario.required_datasets() - exists in the SeedDatasetProvider registry. + This test ensures that every dataset name listed in scenario metadata exists in + the SeedDatasetProvider registry. """ - # Get all available dataset names from SeedDatasetProvider available_datasets = set(await SeedDatasetProvider.get_all_dataset_names_async()) - # Get ScenarioRegistry to discover all scenarios registry = ScenarioRegistry.get_registry_singleton() - scenario_names = registry.get_names() - # Collect all required datasets from all scenarios missing_datasets: list[str] = [] - scenario_dataset_map: dict[str, list[str]] = {} - - for scenario_name in scenario_names: - scenario_class = registry.get_class(scenario_name) - if scenario_class: - try: - required = scenario_class.default_dataset_config().get_default_dataset_names() - scenario_dataset_map[scenario_name] = required - - missing_datasets.extend( - f"{scenario_name} requires '{dataset_name}'" - for dataset_name in required - if dataset_name not in available_datasets - ) - except Exception as e: - # Log but don't fail - some scenarios might not be fully initialized - print(f"Warning: Could not get required datasets from {scenario_name}: {e}") + for metadata in registry.list_metadata(): + missing_datasets.extend( + f"{metadata.registry_name} requires '{dataset_name}'" + for dataset_name in metadata.default_datasets + if dataset_name not in available_datasets + ) - # Assert that all required datasets are available assert len(missing_datasets) == 0, ( "The following scenarios require datasets not available in SeedDatasetProvider:\n" + "\n".join(missing_datasets) @@ -208,38 +137,15 @@ async def test_initialize_async_empty_dataset_list(self) -> None: """Test initialization when scenarios return empty dataset lists.""" initializer = LoadDefaultDatasets() - mock_dataset_config = MagicMock() - mock_dataset_config.get_default_dataset_names.return_value = [] - mock_scenario = MagicMock(spec=Scenario) - mock_scenario.default_dataset_config.return_value = mock_dataset_config - - with patch.object(ScenarioRegistry, "get_names", return_value=["empty_scenario"]): - with patch.object(ScenarioRegistry, "get_class", return_value=mock_scenario): - with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: - with patch.object(CentralMemory, "get_memory_instance") as mock_memory: - mock_memory_instance = MagicMock() - mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() - mock_memory.return_value = mock_memory_instance - - await initializer.initialize_async() - - # Should not fetch datasets when all scenarios return empty lists - mock_fetch.assert_not_called() - mock_memory_instance.add_seed_datasets_to_memory_async.assert_not_called() - - async def test_initialize_async_none_scenario_class(self) -> None: - """Test initialization when get_scenario returns None for a scenario.""" - initializer = LoadDefaultDatasets() + metadata = [_FakeMetadata(registry_name="empty_scenario", default_datasets=())] - with patch.object(ScenarioRegistry, "get_names", return_value=["nonexistent_scenario"]): - with patch.object(ScenarioRegistry, "get_class", return_value=None): - with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: - with patch.object(CentralMemory, "get_memory_instance") as mock_memory: - mock_memory_instance = MagicMock() - mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() - mock_memory.return_value = mock_memory_instance + with patch.object(ScenarioRegistry, "list_metadata", return_value=metadata): + with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: + with patch.object(CentralMemory, "get_memory_instance") as mock_memory: + mock_memory_instance = MagicMock() + mock_memory_instance.add_seed_datasets_to_memory_async = AsyncMock() + mock_memory.return_value = mock_memory_instance - await initializer.initialize_async() + await initializer.initialize_async() - # Should not crash, just skip the None scenario - mock_fetch.assert_not_called() + mock_fetch.assert_not_called()