From b3bd47050f7ea54a0febb38575de841ebba5fd2e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 22 May 2026 17:40:30 -0700 Subject: [PATCH] Run atomic attacks in parallel within a scenario Allow multiple atomic attacks in a scenario to run concurrently, driven by the existing max_concurrency parameter. All in-flight objectives across all atomic attacks share a single asyncio.Semaphore(max_concurrency), so the global concurrent-objective budget is bounded by max_concurrency regardless of how work is distributed across atomic attacks. A long-running attack can elastically use slots freed by short-running siblings. Changes: - AttackExecutor now accepts an optional external semaphore kwarg. When provided, it gates both seed-group parameter building and per-objective execution, letting a parent (e.g. Scenario) share one budget across many executors. - AtomicAttack.run_async forwards the optional semaphore to its executor. - Scenario._execute_scenario_async: when max_concurrency > 1, creates one shared semaphore and launches every remaining atomic attack via asyncio.gather, all sharing that semaphore. When max_concurrency == 1, runs the atomic attacks one at a time in a dedicated sequential helper, preserving abort-on-first-failure semantics. - Parallel failure mode uses gather(return_exceptions=True) so in-flight siblings finish before the first error is re-raised (preserves partial work for resume). - No new user-facing parameters or CLI flags. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/executor/attack/core/attack_executor.py | 20 +- pyrit/scenario/core/atomic_attack.py | 12 +- pyrit/scenario/core/scenario.py | 279 ++++++++++++------ tests/unit/scenario/test_atomic_attack.py | 4 +- tests/unit/scenario/test_scenario.py | 236 ++++++++++++++- tests/unit/scenario/test_scenario_retry.py | 3 + 6 files changed, 448 insertions(+), 106 deletions(-) diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index 6405b6816..8a72bb455 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -120,12 +120,19 @@ class AttackExecutor: from seed groups. """ - def __init__(self, *, max_concurrency: int = 1) -> None: + def __init__(self, *, max_concurrency: int = 1, semaphore: Optional[asyncio.Semaphore] = None) -> None: """ Initialize the attack executor with configurable concurrency control. Args: max_concurrency: Maximum number of concurrent attack executions (default: 1). + Ignored when ``semaphore`` is provided. + semaphore: Optional externally-owned ``asyncio.Semaphore`` used to gate + concurrent objective execution. When provided, all concurrency control + (both seed-group parameter building and attack execution) is delegated + to this semaphore, allowing a parent (e.g., a Scenario) to share a + single budget across many executors. When ``None`` (default), a new + semaphore is created from ``max_concurrency``. Raises: ValueError: If max_concurrency is not a positive integer. 
@@ -133,6 +140,13 @@ def __init__(self, *, max_concurrency: int = 1) -> None: if max_concurrency <= 0: raise ValueError(f"max_concurrency must be a positive integer, got {max_concurrency}") self._max_concurrency = max_concurrency + self._external_semaphore = semaphore + + def _get_semaphore(self) -> asyncio.Semaphore: + """Return the externally-supplied semaphore, or a fresh one sized to max_concurrency.""" + if self._external_semaphore is not None: + return self._external_semaphore + return asyncio.Semaphore(self._max_concurrency) async def execute_attack_from_seed_groups_async( self, @@ -193,7 +207,7 @@ async def execute_attack_from_seed_groups_async( # Build params list using from_seed_group_async with concurrency control # This can take time if the SeedSimulatedConversation generation is included - semaphore = asyncio.Semaphore(self._max_concurrency) + semaphore = self._get_semaphore() async def build_params(i: int, sg: SeedAttackGroup) -> AttackParameters: async with semaphore: @@ -309,7 +323,7 @@ async def _execute_with_params_list_async( Returns: AttackExecutorResult with completed results and any incomplete objectives. """ - semaphore = asyncio.Semaphore(self._max_concurrency) + semaphore = self._get_semaphore() async def run_one(index: int, params: AttackParameters) -> AttackStrategyResultT: async with semaphore: diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index 1138671f3..48ef7c4a4 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -13,6 +13,7 @@ have a common interface for scenarios. 
""" +import asyncio import logging from typing import TYPE_CHECKING, Any, Optional @@ -303,6 +304,7 @@ async def run_async( *, max_concurrency: int = 1, return_partial_on_failure: bool = True, + semaphore: asyncio.Semaphore | None = None, **attack_params: Any, ) -> AttackExecutorResult[AttackResult]: """ @@ -321,10 +323,16 @@ async def run_async( Args: max_concurrency (int): Maximum number of concurrent attack executions. - Defaults to 1 for sequential execution. + Defaults to 1 for sequential execution. Ignored when ``semaphore`` + is provided. return_partial_on_failure (bool): If True, returns partial results even when some objectives don't complete execution. If False, raises an exception on any execution failure. Defaults to True. + semaphore (asyncio.Semaphore | None): Optional externally-owned semaphore + used to gate objective execution. Allows a parent (e.g., a Scenario + running multiple atomic attacks in parallel) to share a single + concurrency budget across all of them. When provided, takes precedence + over ``max_concurrency``. **attack_params: Additional parameters to pass to the attack strategy. Returns: @@ -334,7 +342,7 @@ async def run_async( Raises: ValueError: If the attack execution fails completely and return_partial_on_failure=False. """ - executor = AttackExecutor(max_concurrency=max_concurrency) + executor = AttackExecutor(max_concurrency=max_concurrency, semaphore=semaphore) logger.info( f"Starting atomic attack execution with {len(self._seed_groups)} seed groups " diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index aa48c3daa..32a86aaa2 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -8,6 +8,7 @@ AtomicAttack instances sequentially, enabling comprehensive security testing campaigns. """ +import asyncio import copy import json import logging @@ -607,6 +608,9 @@ async def initialize_async( Use this to specify dataset names or maximum dataset size from the CLI. 
If not provided, scenarios use their default_dataset_config(). max_concurrency (int): Maximum number of concurrent attack executions. Defaults to 1. + This is the total in-flight objective budget for the scenario. When greater than 1, + atomic attacks run concurrently and share one semaphore of this size, so in-flight + objectives across all atomic attacks are bounded by this value. max_retries (int): Maximum number of automatic retries if the scenario raises an exception. Set to 0 (default) for no automatic retries. If set to a positive number, the scenario will automatically retry up to this many times after an exception. @@ -1230,97 +1234,23 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Calculate starting index based on completed attacks completed_count = len(self._atomic_attacks) - len(remaining_attacks) + # When max_concurrency == 1, run atomic attacks serially with abort-on-first-failure + # semantics. Otherwise, run them all in parallel sharing a single objective-level + # Semaphore(max_concurrency) so the global in-flight objective budget never exceeds + # max_concurrency regardless of how the work is distributed across atomic attacks. try: - for i, atomic_attack in enumerate( - tqdm( - remaining_attacks, - desc=f"Executing {self._name}", - unit="attack", - total=len(self._atomic_attacks), - initial=completed_count, - ), - start=completed_count + 1, - ): - # Stamp the scenario id onto the atomic attack so each persisted - # AttackResult carries the attribution_parent_id linkage. This - # is what enables mid-run interruption recovery (results are - # visible without the post-atomic-attack bulk manifest write). 
- atomic_attack.set_scenario_result_id(scenario_result_id) - - logger.info( - f"Executing atomic attack {i}/{len(self._atomic_attacks)} " - f"('{atomic_attack.atomic_attack_name}') in scenario '{self._name}'" + if self._max_concurrency <= 1: + await self._execute_atomic_attacks_sequential_async( + remaining_attacks=remaining_attacks, + scenario_result_id=scenario_result_id, + completed_count=completed_count, + ) + else: + await self._execute_atomic_attacks_parallel_async( + remaining_attacks=remaining_attacks, + scenario_result_id=scenario_result_id, + completed_count=completed_count, ) - - try: - atomic_results = await atomic_attack.run_async( - max_concurrency=self._max_concurrency, - return_partial_on_failure=True, - ) - - # Per-result scenario linkage is now stamped by the attack - # event handler at write time; no post-atomic bulk update. - - # Check if there were any incomplete objectives - if atomic_results.has_incomplete: - incomplete_count = len(atomic_results.incomplete_objectives) - completed_count = len(atomic_results.completed_results) - - logger.error( - f"Atomic attack {i}/{len(self._atomic_attacks)} " - f"('{atomic_attack.atomic_attack_name}') partially completed: " - f"{completed_count} completed, {incomplete_count} incomplete" - ) - - # Log details of each incomplete objective - for obj, exc in atomic_results.incomplete_objectives: - logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") - - # Error AttackResults are linked to this scenario via the - # attribution_parent_id foreign key on AttackResultEntry - # (stamped by the attack event handler when an - # AttackResultAttribution is on the context). The - # previous per-scenario error_id manifest is no longer - # needed. - - # Mark scenario as failed - error_msg = ( - f"Atomic attack '{atomic_attack.atomic_attack_name}' partially failed: " - f"{incomplete_count} of {incomplete_count + completed_count} objectives incomplete. " - f"See attack results for details." 
- ) - self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, - scenario_run_state="FAILED", - error_message=error_msg, - error_type=type(atomic_results.incomplete_objectives[0][1]).__name__, - ) - - # Raise exception with detailed information - raise ValueError(error_msg) from atomic_results.incomplete_objectives[0][1] - logger.info( - f"Atomic attack {i}/{len(self._atomic_attacks)} completed successfully with " - f"{len(atomic_results.completed_results)} results" - ) - - except Exception as e: - # Exception was raised either by run_async or by our check above - logger.error( - f"Atomic attack {i}/{len(self._atomic_attacks)} " - f"('{atomic_attack.atomic_attack_name}') failed in scenario '{self._name}': {str(e)}" - ) - - # Mark scenario as failed if not already done - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) - if scenario_results and scenario_results[0].scenario_run_state != "FAILED": - self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, - scenario_run_state="FAILED", - error_message=str(e), - error_type=type(e).__name__, - ) - - raise logger.info(f"Scenario '{self._name}' completed successfully") @@ -1339,3 +1269,174 @@ async def _execute_scenario_async(self) -> ScenarioResult: except Exception as e: logger.error(f"Scenario '{self._name}' failed with error: {str(e)}") raise + + def _partial_result_to_exception( + self, + *, + atomic_attack: AtomicAttack, + atomic_results: Any, + ) -> ValueError | None: + """ + Log the outcome of an atomic attack and return an exception if it didn't + fully complete. + + Returns: + ValueError | None: An error to raise when the atomic attack has incomplete + objectives, otherwise ``None`` when all objectives finished successfully. 
+ """ + if not atomic_results.has_incomplete: + logger.info( + f"Atomic attack ('{atomic_attack.atomic_attack_name}') completed successfully with " + f"{len(atomic_results.completed_results)} results" + ) + return None + + incomplete_count = len(atomic_results.incomplete_objectives) + completed_in_run = len(atomic_results.completed_results) + logger.error( + f"Atomic attack ('{atomic_attack.atomic_attack_name}') partially completed: " + f"{completed_in_run} completed, {incomplete_count} incomplete" + ) + for obj, exc in atomic_results.incomplete_objectives: + logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") + + inner = atomic_results.incomplete_objectives[0][1] + error = ValueError( + f"Atomic attack '{atomic_attack.atomic_attack_name}' partially failed: " + f"{incomplete_count} of {incomplete_count + completed_in_run} objectives incomplete. " + f"See attack results for details." + ) + if isinstance(inner, BaseException): + error.__cause__ = inner + return error + + def _mark_scenario_failed(self, *, scenario_result_id: str, error: BaseException) -> None: + """Mark the scenario run as FAILED, deriving message/type from ``error``.""" + cause = error.__cause__ if error.__cause__ is not None else error + self._memory.update_scenario_run_state( + scenario_result_id=scenario_result_id, + scenario_run_state="FAILED", + error_message=str(error), + error_type=type(cause).__name__, + ) + + async def _execute_atomic_attacks_sequential_async( + self, + *, + remaining_attacks: list[AtomicAttack], + scenario_result_id: str, + completed_count: int, + ) -> None: + """ + Execute atomic attacks one at a time. First failure marks the scenario FAILED + and raises immediately. This is the default behavior preserved from prior versions. + + Raises: + ValueError: If an atomic attack returns incomplete objectives. + Exception: Re-raised from a failing atomic attack. 
+ """ + progress = tqdm( + remaining_attacks, + desc=f"Executing {self._name}", + unit="attack", + total=len(self._atomic_attacks), + initial=completed_count, + ) + total = len(self._atomic_attacks) + + for i, atomic_attack in enumerate(progress, start=completed_count + 1): + atomic_attack.set_scenario_result_id(scenario_result_id) + logger.info( + f"Executing atomic attack {i}/{total} " + f"('{atomic_attack.atomic_attack_name}') in scenario '{self._name}'" + ) + + try: + atomic_results = await atomic_attack.run_async( + max_concurrency=self._max_concurrency, + return_partial_on_failure=True, + ) + except Exception as e: + logger.error( + f"Atomic attack {i}/{total} " + f"('{atomic_attack.atomic_attack_name}') failed in scenario '{self._name}': {str(e)}" + ) + self._mark_scenario_failed(scenario_result_id=scenario_result_id, error=e) + raise + + error = self._partial_result_to_exception( + atomic_attack=atomic_attack, atomic_results=atomic_results + ) + if error is not None: + self._mark_scenario_failed(scenario_result_id=scenario_result_id, error=error) + raise error + + async def _execute_atomic_attacks_parallel_async( + self, + *, + remaining_attacks: list[AtomicAttack], + scenario_result_id: str, + completed_count: int, + ) -> None: + """ + Execute all remaining atomic attacks concurrently. All in-flight objectives + across all atomic attacks share a single ``Semaphore(max_concurrency)`` so the + global concurrent-objective budget is bounded by ``max_concurrency`` regardless + of how work is distributed across atomic attacks. This means a long-running + atomic attack can elastically use freed slots from short ones. + + Failure semantics differ from the sequential path: when an atomic attack fails or + returns ``has_incomplete``, in-flight siblings are allowed to finish (so their + partial work persists for resume), then the first error is re-raised. 
+ """ + shared_semaphore = asyncio.Semaphore(self._max_concurrency) + pbar = tqdm( + desc=f"Executing {self._name}", + unit="attack", + total=len(self._atomic_attacks), + initial=completed_count, + ) + + for atomic_attack in remaining_attacks: + atomic_attack.set_scenario_result_id(scenario_result_id) + + logger.info( + f"Launching {len(remaining_attacks)} atomic attacks in parallel " + f"(shared max_concurrency={self._max_concurrency}) in scenario '{self._name}'" + ) + + async def run_one(atomic_attack: AtomicAttack) -> tuple[AtomicAttack, Any]: + try: + result = await atomic_attack.run_async( + max_concurrency=self._max_concurrency, + return_partial_on_failure=True, + semaphore=shared_semaphore, + ) + return atomic_attack, result + finally: + pbar.update(1) + + try: + outcomes = await asyncio.gather( + *(run_one(aa) for aa in remaining_attacks), + return_exceptions=True, + ) + finally: + pbar.close() + + first_error: BaseException | None = None + for outcome in outcomes: + if isinstance(outcome, BaseException): + logger.error(f"Atomic attack failed in scenario '{self._name}': {str(outcome)}") + error = outcome + else: + atomic_attack, atomic_results = outcome + error = self._partial_result_to_exception( + atomic_attack=atomic_attack, atomic_results=atomic_results + ) + if error is not None and first_error is None: + first_error = error + + if first_error is not None: + self._mark_scenario_failed(scenario_result_id=scenario_result_id, error=first_error) + raise first_error diff --git a/tests/unit/scenario/test_atomic_attack.py b/tests/unit/scenario/test_atomic_attack.py index b17db749f..1a4b38b76 100644 --- a/tests/unit/scenario/test_atomic_attack.py +++ b/tests/unit/scenario/test_atomic_attack.py @@ -263,7 +263,7 @@ async def test_run_async_with_custom_concurrency(self, mock_attack, sample_seed_ result = await atomic_attack.run_async(max_concurrency=5) - mock_init.assert_called_once_with(max_concurrency=5) + mock_init.assert_called_once_with(max_concurrency=5, 
semaphore=None) assert len(result.completed_results) == 3 async def test_run_async_with_default_concurrency(self, mock_attack, sample_seed_groups, sample_attack_results): @@ -282,7 +282,7 @@ async def test_run_async_with_default_concurrency(self, mock_attack, sample_seed await atomic_attack.run_async() - mock_init.assert_called_once_with(max_concurrency=1) + mock_init.assert_called_once_with(max_concurrency=1, semaphore=None) async def test_run_async_passes_memory_labels(self, mock_attack, sample_seed_groups, sample_attack_results): """Test that memory labels are passed to the executor.""" diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index a4971a3e0..219b99249 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -3,8 +3,9 @@ """Tests for the scenarios.Scenario class.""" +import asyncio from typing import ClassVar -from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, PropertyMock, patch import pytest @@ -355,7 +356,7 @@ class TestScenarioExecution: """Tests for Scenario execution methods.""" async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): - """Test that run_async executes all atomic attacks sequentially.""" + """Test that run_async executes all atomic attacks.""" # Configure each run to return different results for i, run in enumerate(mock_atomic_attacks): run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) @@ -372,10 +373,14 @@ async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_att # Verify return type is ScenarioResult assert isinstance(result, ScenarioResult) - # Verify all runs were executed with correct concurrency + # Verify all runs were executed. 
Default max_concurrency=10 with 3 atomic attacks + # means parallel path: each atomic attack receives max_concurrency=10 and a + # shared semaphore that caps total in-flight objectives. assert len(result.attack_results) == 3 for run in mock_atomic_attacks: - run.run_async.assert_called_once_with(max_concurrency=10, return_partial_on_failure=True) + run.run_async.assert_called_once_with( + max_concurrency=10, return_partial_on_failure=True, semaphore=ANY + ) # Verify results are aggregated correctly by atomic attack name assert "attack_run_1" in result.attack_results @@ -388,7 +393,7 @@ async def test_run_async_executes_all_runs(self, mock_atomic_attacks, sample_att async def test_run_async_with_custom_concurrency( self, mock_atomic_attacks, sample_attack_results, mock_objective_target ): - """Test that max_concurrency from init is passed to each atomic attack.""" + """Test that each atomic attack gets max_concurrency from init plus a shared semaphore.""" for i, run in enumerate(mock_atomic_attacks): run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) @@ -401,9 +406,12 @@ async def test_run_async_with_custom_concurrency( result = await scenario.run_async() - # Verify max_concurrency was passed to each run + # 3 atomic attacks, max_concurrency=5 -> parallel path with shared semaphore. + # Each atomic attack still receives max_concurrency=5 (the semaphore is the real cap). 
for run in mock_atomic_attacks: - run.run_async.assert_called_once_with(max_concurrency=5, return_partial_on_failure=True) + run.run_async.assert_called_once_with( + max_concurrency=5, return_partial_on_failure=True, semaphore=ANY + ) # Verify result structure assert isinstance(result, ScenarioResult) @@ -441,7 +449,7 @@ async def test_run_async_aggregates_multiple_results( assert len(result.attack_results["attack_run_3"]) == 1 async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack_results, mock_objective_target): - """Test that execution stops when an atomic attack fails.""" + """Test that sequential execution (max_concurrency=1) stops on first failure.""" mock_atomic_attacks[0].run_async = create_mock_run_async([sample_attack_results[0]]) mock_atomic_attacks[1].run_async = AsyncMock(side_effect=Exception("Test error")) mock_atomic_attacks[2].run_async = create_mock_run_async([sample_attack_results[2]]) @@ -451,7 +459,8 @@ async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack version=1, atomic_attacks_to_return=mock_atomic_attacks, ) - await scenario.initialize_async(objective_target=mock_objective_target) + # Pin to sequential mode so the abort-on-first-failure semantics apply. 
+ await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=1) with pytest.raises(Exception, match="Test error"): await scenario.run_async() @@ -460,7 +469,7 @@ async def test_run_async_stops_on_error(self, mock_atomic_attacks, sample_attack mock_atomic_attacks[0].run_async.assert_called_once() # Second run should have been attempted mock_atomic_attacks[1].run_async.assert_called_once() - # Third run should not have been executed + # Third run should not have been executed (sequential aborts on first failure) mock_atomic_attacks[2].run_async.assert_not_called() async def test_run_async_fails_without_initialization(self, mock_objective_target): @@ -1194,3 +1203,210 @@ async def test_resume_raises_when_id_not_found(self, mock_objective_target, mock with pytest.raises(ValueError, match="not found in memory"): await scenario.initialize_async(objective_target=mock_objective_target) + + +@pytest.mark.usefixtures("patch_central_database") +class TestScenarioParallelExecution: + """Tests for parallel atomic-attack execution sharing a single max_concurrency budget.""" + + async def test_atomic_attacks_share_one_semaphore( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): + """All atomic attacks in parallel mode receive the same shared semaphore.""" + for i, run in enumerate(mock_atomic_attacks): + run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) + + scenario = ConcreteScenario( + name="Test Scenario", + version=1, + atomic_attacks_to_return=mock_atomic_attacks, + ) + await scenario.initialize_async( + objective_target=mock_objective_target, + max_concurrency=4, + ) + + await scenario.run_async() + + # Each atomic attack got max_concurrency=4 and a semaphore kwarg, and it's the + # same Semaphore instance across all three. 
+ semaphores_seen = [] + for run in mock_atomic_attacks: + assert run.run_async.call_count == 1 + kwargs = run.run_async.call_args.kwargs + assert kwargs["max_concurrency"] == 4 + assert kwargs["return_partial_on_failure"] is True + assert isinstance(kwargs["semaphore"], asyncio.Semaphore) + semaphores_seen.append(kwargs["semaphore"]) + assert semaphores_seen[0] is semaphores_seen[1] is semaphores_seen[2] + + async def test_shared_semaphore_bounds_global_concurrency( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): + """Total in-flight objectives across all atomic attacks never exceeds max_concurrency. + + Simulates each atomic attack 'using' the semaphore for two objectives. With + max_concurrency=2 and 3 atomic attacks (= 6 objectives total), peak in-flight + objective count must stay <= 2 even though all three atomic attacks are launched. + """ + peak = [0] + in_flight = [0] + lock = asyncio.Lock() + + def make_run_async(idx): + async def run_async(*, semaphore, **kwargs): + # Simulate two objectives per atomic attack, each acquiring the shared sem. 
+ for _ in range(2): + async with semaphore: + async with lock: + in_flight[0] += 1 + peak[0] = max(peak[0], in_flight[0]) + await asyncio.sleep(0.02) + async with lock: + in_flight[0] -= 1 + _stamp_scenario_linkage( + attack_results=[sample_attack_results[idx]], + atomic_attack=mock_atomic_attacks[idx], + ) + save_attack_results_to_memory([sample_attack_results[idx]]) + return AttackExecutorResult( + completed_results=[sample_attack_results[idx]], incomplete_objectives=[] + ) + + return AsyncMock(side_effect=run_async) + + for i, run in enumerate(mock_atomic_attacks): + run.run_async = make_run_async(i) + + scenario = ConcreteScenario( + name="Test Scenario", + version=1, + atomic_attacks_to_return=mock_atomic_attacks, + ) + await scenario.initialize_async( + objective_target=mock_objective_target, + max_concurrency=2, + ) + + await scenario.run_async() + + assert peak[0] <= 2, f"shared semaphore violated: peak in-flight was {peak[0]}" + assert peak[0] == 2, f"expected to saturate budget of 2, peaked at {peak[0]}" + + async def test_atomic_attacks_run_concurrently( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): + """When max_concurrency permits, multiple atomic attacks are in-flight simultaneously.""" + started = asyncio.Event() + in_flight = 0 + max_in_flight = 0 + lock = asyncio.Lock() + + def make_run_async(idx): + async def run_async(*args, **kwargs): + nonlocal in_flight, max_in_flight + async with lock: + in_flight += 1 + max_in_flight = max(max_in_flight, in_flight) + if in_flight >= 3: + started.set() + try: + await asyncio.wait_for(started.wait(), timeout=2.0) + finally: + async with lock: + in_flight -= 1 + _stamp_scenario_linkage( + attack_results=[sample_attack_results[idx]], + atomic_attack=mock_atomic_attacks[idx], + ) + save_attack_results_to_memory([sample_attack_results[idx]]) + return AttackExecutorResult( + completed_results=[sample_attack_results[idx]], incomplete_objectives=[] + ) + + return 
AsyncMock(side_effect=run_async) + + for i, run in enumerate(mock_atomic_attacks): + run.run_async = make_run_async(i) + + scenario = ConcreteScenario( + name="Test Scenario", + version=1, + atomic_attacks_to_return=mock_atomic_attacks, + ) + await scenario.initialize_async( + objective_target=mock_objective_target, + max_concurrency=6, + ) + + result = await scenario.run_async() + + assert max_in_flight == 3, f"expected all 3 atomic attacks in flight, peaked at {max_in_flight}" + assert len(result.attack_results) == 3 + + async def test_failure_lets_siblings_finish( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): + """When one atomic attack fails in parallel mode, in-flight siblings still complete.""" + completed_calls: list[str] = [] + + async def ok_run(idx, name): + await asyncio.sleep(0.05) + completed_calls.append(name) + _stamp_scenario_linkage( + attack_results=[sample_attack_results[idx]], + atomic_attack=mock_atomic_attacks[idx], + ) + save_attack_results_to_memory([sample_attack_results[idx]]) + return AttackExecutorResult( + completed_results=[sample_attack_results[idx]], incomplete_objectives=[] + ) + + async def bad_run(*args, **kwargs): + raise RuntimeError("boom") + + async def side_run_0(*a, **k): + return await ok_run(0, "attack_run_1") + + async def side_run_2(*a, **k): + return await ok_run(2, "attack_run_3") + + mock_atomic_attacks[0].run_async = AsyncMock(side_effect=side_run_0) + mock_atomic_attacks[1].run_async = AsyncMock(side_effect=bad_run) + mock_atomic_attacks[2].run_async = AsyncMock(side_effect=side_run_2) + + scenario = ConcreteScenario( + name="Test Scenario", + version=1, + atomic_attacks_to_return=mock_atomic_attacks, + ) + await scenario.initialize_async( + objective_target=mock_objective_target, + max_concurrency=6, + ) + + with pytest.raises(RuntimeError, match="boom"): + await scenario.run_async() + + # Both non-failing siblings ran to completion before the error was raised. 
+ assert "attack_run_1" in completed_calls + assert "attack_run_3" in completed_calls + + async def test_sequential_when_max_concurrency_is_one( + self, mock_atomic_attacks, sample_attack_results, mock_objective_target + ): + """max_concurrency=1 forces sequential path; each attack gets max_concurrency=1 and no semaphore.""" + for i, run in enumerate(mock_atomic_attacks): + run.run_async = create_mock_run_async([sample_attack_results[i]], atomic_attack=run) + + scenario = ConcreteScenario( + name="Test Scenario", + version=1, + atomic_attacks_to_return=mock_atomic_attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=1) + + await scenario.run_async() + + for run in mock_atomic_attacks: + run.run_async.assert_called_once_with(max_concurrency=1, return_partial_on_failure=True) diff --git a/tests/unit/scenario/test_scenario_retry.py b/tests/unit/scenario/test_scenario_retry.py index d26cb1ae0..3e80cd429 100644 --- a/tests/unit/scenario/test_scenario_retry.py +++ b/tests/unit/scenario/test_scenario_retry.py @@ -291,6 +291,7 @@ async def mock_run_with_retry(*args, **kwargs): ) await scenario.initialize_async( objective_target=mock_objective_target, + max_concurrency=1, max_retries=2, ) @@ -372,6 +373,7 @@ async def mock_run_with_multiple_retries(*args, **kwargs): ) await scenario.initialize_async( objective_target=mock_objective_target, + max_concurrency=1, max_retries=3, ) @@ -408,6 +410,7 @@ async def mock_run_with_logged_failure(*args, **kwargs): ) await scenario.initialize_async( objective_target=mock_objective_target, + max_concurrency=1, max_retries=1, )