From a24ec71989ac88ad8eb8a49f57c93ae44ac27943 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 12:14:05 -0700 Subject: [PATCH 01/42] feat(scenario): scaffold state-machine refactor (Phase 0 + OutcomeScorer) Land the empty-but-tested new abstractions for the scenario core refactor side by side with the existing flat-loop scenario plumbing. Nothing in pyrit/scenario/core/scenario.py changes yet; later phases wire these in. New modules: - pyrit/scenario/core/scenario_state.py: ScenarioCoreState enum (UNINITIALIZED, INITIALIZING, EXECUTING, COMPLETE, FAILED) plus ScenarioStateLike runtime-checkable protocol. Per-scenario state enums extend the vocabulary by satisfying the protocol. - pyrit/scenario/core/scenario_step.py: ScenarioStep(Identifiable) ABC plus frozen ScenarioStepResult dataclass. One step owns one outcome decision (may wrap N attack executions). - pyrit/scenario/core/strategy_graph.py: generic StrategyGraph orchestrator over a policy dict[state, async-action]. Restartable event_loop_async yields ScenarioStepResults; history tracked for resume. Constructor validates terminal_states, initial_state, and policy/terminal overlap. - pyrit/score/decorators/outcome_scorer.py: OutcomeScorer composition wrapper around a Scorer. resolve_outcome_async returns the first matching label from outcome_map, or the 'unscored' sentinel. Not a Scorer subclass on purpose (composition keeps the Scorer ABC's validator and abstract methods out of the way). - pyrit/identifiers/step_identifier.py: build_step_identifier factory plus STEP_EVAL_VERSION constant. Composite identifier wraps N atomic_attack_identifiers under children['attack_executions']. atomic_attack_identifier is unchanged: step identity is additive. Exports: - pyrit.identifiers re-exports build_step_identifier, STEP_EVAL_VERSION - pyrit.score re-exports OutcomeScorer Tests (44 new, all green): construction validation, identifier determinism, event-loop traversal, restartability across exceptions, outcome resolution and ordering, unscored fallback for empty score lists and unmatched predicates. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/identifiers/__init__.py | 3 + pyrit/identifiers/step_identifier.py | 90 ++++++++ pyrit/scenario/core/scenario_state.py | 61 +++++ pyrit/scenario/core/scenario_step.py | 104 +++++++++ pyrit/scenario/core/strategy_graph.py | 187 +++++++++++++++ pyrit/score/__init__.py | 2 + pyrit/score/decorators/__init__.py | 8 + pyrit/score/decorators/outcome_scorer.py | 121 ++++++++++ .../unit/identifiers/test_step_identifier.py | 145 ++++++++++++ tests/unit/scenario/test_scenario_state.py | 32 +++ tests/unit/scenario/test_scenario_step.py | 69 ++++++ tests/unit/scenario/test_strategy_graph.py | 217 ++++++++++++++++++ .../score/decorators/test_outcome_scorer.py | 157 +++++++++++++ 13 files changed, 1196 insertions(+) create mode 100644 pyrit/identifiers/step_identifier.py create mode 100644 pyrit/scenario/core/scenario_state.py create mode 100644 pyrit/scenario/core/scenario_step.py create mode 100644 pyrit/scenario/core/strategy_graph.py create mode 100644 pyrit/score/decorators/__init__.py create mode 100644 pyrit/score/decorators/outcome_scorer.py create mode 100644 tests/unit/identifiers/test_step_identifier.py create mode 100644 tests/unit/scenario/test_scenario_state.py create mode 100644 tests/unit/scenario/test_scenario_step.py create mode 100644 tests/unit/scenario/test_strategy_graph.py create mode 100644 tests/unit/score/decorators/test_outcome_scorer.py diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index a85c2caca..c8e46b5ae 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -22,11 +22,13 @@ compute_eval_hash, ) from pyrit.identifiers.identifier_filters import IdentifierFilter, IdentifierType +from pyrit.identifiers.step_identifier import STEP_EVAL_VERSION, build_step_identifier __all__ = [ "AtomicAttackEvaluationIdentifier", "build_atomic_attack_identifier", "build_seed_identifier", + "build_step_identifier", "ChildEvalRule", "class_name_to_snake_case", "ComponentIdentifier", @@ -36,6 +38,7 @@ "REGISTRY_NAME_PATTERN", "ScorerEvaluationIdentifier", "snake_case_to_class_name", + "STEP_EVAL_VERSION", "validate_registry_name", "config_hash", "IdentifierFilter", diff --git a/pyrit/identifiers/step_identifier.py b/pyrit/identifiers/step_identifier.py new file mode 100644 index 000000000..d1e9b7d59 --- /dev/null +++ b/pyrit/identifiers/step_identifier.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario step identity builder. + +Builds a composite ``ComponentIdentifier`` that wraps one or more +``atomic_attack_identifier``s with the step's name and emitted outcome label. +This is the *step-level* identity that the scenario state machine produces; +the underlying ``atomic_attack_identifier`` is **not** renamed or removed and +continues to live on every ``AttackResult`` exactly as before. + +Bump ``STEP_EVAL_VERSION`` when changing what is included in the step +identifier; this lets ``StepEvaluationIdentifier`` (Phase 4) detect schema +drift without falsely conflating old and new hashes. + +Composite shape:: + + ScenarioStep + ├── step_name (param) + ├── outcome (param) + ├── eval_version (param) + └── attack_executions (children, list of atomic_attack_identifier) +""" + +from __future__ import annotations + +import logging +from typing import Any + +from pyrit.identifiers.component_identifier import ComponentIdentifier + +logger = logging.getLogger(__name__) + +#: Schema version for ``build_step_identifier``. Bump on any change that +#: affects which params or children participate in the identity (and thus the +#: derived hash / eval hash). Stays inside the identifier's ``params`` so old +#: rows preserve their original version. +STEP_EVAL_VERSION: int = 1 + +_SCENARIO_STEP_CLASS_NAME = "ScenarioStep" +_SCENARIO_STEP_CLASS_MODULE = "pyrit.scenario.core.scenario_step" + + +def build_step_identifier( + *, + step_name: str, + outcome: str, + attack_execution_identifiers: list[ComponentIdentifier], +) -> ComponentIdentifier: + """ + Build a composite ``ComponentIdentifier`` for one step execution. + + Args: + step_name (str): The step's ``name`` attribute. + outcome (str): The transition label the step emitted. + attack_execution_identifiers (list[ComponentIdentifier]): The + ``atomic_attack_identifier``s produced by every attack execution + the step ran, in execution order. May be empty for steps that + record outcomes from external signals. + + Returns: + ComponentIdentifier: Composite identifier with + ``class_name="ScenarioStep"`` and the attack executions nested + under ``children["attack_executions"]``. + + Raises: + ValueError: If ``step_name`` or ``outcome`` is empty. + """ + if not step_name: + raise ValueError("step_name must be non-empty.") + if not outcome: + raise ValueError("outcome must be non-empty.") + + params: dict[str, Any] = { + "step_name": step_name, + "outcome": outcome, + "eval_version": STEP_EVAL_VERSION, + } + + children: dict[str, Any] = { + "attack_executions": list(attack_execution_identifiers), + } + + return ComponentIdentifier( + class_name=_SCENARIO_STEP_CLASS_NAME, + class_module=_SCENARIO_STEP_CLASS_MODULE, + params=params, + children=children, + ) diff --git a/pyrit/scenario/core/scenario_state.py b/pyrit/scenario/core/scenario_state.py new file mode 100644 index 000000000..9e18c3400 --- /dev/null +++ b/pyrit/scenario/core/scenario_state.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario state vocabulary for the state-machine refactor. + +``ScenarioCoreState`` provides the small set of states every scenario shares: +lifecycle transitions before any work runs, while work is running, and after +work has finished or failed. + +Scenarios that need richer states (e.g., ``OPENING_PHASE``, ``ESCALATING``) +declare their own enum that satisfies the ``ScenarioStateLike`` protocol. + +This module is part of the scenario core refactor scaffold (Phase 0). It is +not yet wired into ``Scenario.run_async``; ``StrategyGraph`` in Phase 3 +consumes it. +""" + +from __future__ import annotations + +from enum import Enum +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ScenarioStateLike(Protocol): + """ + Structural marker for any enum used as a scenario state. + + Both ``ScenarioCoreState`` and per-scenario state enums satisfy this + protocol. ``StrategyGraph`` accepts any value that is hashable and has a + string ``name`` (the standard ``Enum`` contract), so the only constraint + is that scenarios declare their states as enums. + """ + + name: str + value: object + + +class ScenarioCoreState(Enum): + """ + Lifecycle states shared by every scenario. + + Per-scenario state enums extend this vocabulary by declaring their own + enum class with additional members. + """ + + #: Scenario object constructed but ``initialize_async`` has not run. + UNINITIALIZED = "uninitialized" + + #: ``initialize_async`` is running (pre-flight, graph build, state init). + INITIALIZING = "initializing" + + #: At least one step has started; graph traversal is active. + EXECUTING = "executing" + + #: All steps finished and the graph reached a terminal accepting state. + COMPLETE = "complete" + + #: A step raised or the graph reached a terminal rejecting state. + FAILED = "failed" diff --git a/pyrit/scenario/core/scenario_step.py b/pyrit/scenario/core/scenario_step.py new file mode 100644 index 000000000..8137327a7 --- /dev/null +++ b/pyrit/scenario/core/scenario_step.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``ScenarioStep`` — one outcome decision in a scenario's state machine. + +A ``ScenarioStep`` is strictly coarser than an ``AtomicAttack``: it may invoke +multiple attack-technique executions whose collective results determine a +single transition label emitted to the surrounding ``StrategyGraph``'s +policy. + +This module is part of the scenario core refactor scaffold (Phase 0). The ABC +is in place but no scenario consumes it yet; ``LinearAtomicStep`` (Phase 2) +becomes the first concrete implementation and ``AtomicAttack`` is then +re-rooted onto it for backward compatibility. +""" + +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from pyrit.identifiers import ComponentIdentifier, Identifiable + +if TYPE_CHECKING: + from pyrit.models import AttackResult + + +@dataclass(frozen=True) +class ScenarioStepResult: + """ + Outcome of a single ``ScenarioStep.process_async`` invocation. + + Attributes: + outcome (str): The transition label emitted to the surrounding + ``StrategyGraph``'s policy. Must be one of the step's declared + ``outputs``. + attack_results (list[AttackResult]): Every ``AttackResult`` produced + during the step's execution, in the order they were produced. + May be empty for steps that record outcomes from external signals + rather than running attacks. + step_identifier (ComponentIdentifier | None): The composite identifier + for this step execution. ``None`` until Phase 4 lands the + ``step_identifier`` persistence column; populated after that for + scenarios that opt into graph-based execution. + """ + + outcome: str + attack_results: list[AttackResult] = field(default_factory=list) + step_identifier: ComponentIdentifier | None = None + + +class ScenarioStep(Identifiable): + """ + Abstract base for one outcome decision in a scenario. + + Subclasses declare their valid input and output vocabulary, then implement + ``process_async`` to execute attacks (or any other work) and return a + ``ScenarioStepResult`` whose ``outcome`` is the transition label the + surrounding ``StrategyGraph`` will use to advance state. + + Step granularity rule: one step owns *one* outcome decision. A step may + wrap multiple attack-technique executions whose collective results + determine that one transition label. + """ + + #: Display / resume key. Must be unique within an execution graph. + name: str + + #: Declared transition labels this step can emit. ``"unscored"`` is + #: implicitly always allowed for steps that use an ``OutcomeScorer``. + outputs: list[str] + + @abstractmethod + async def process_async(self) -> ScenarioStepResult: + """ + Execute the step's work and return the outcome plus any attack results. + + Returns: + ScenarioStepResult: The transition label and any attack results + produced. The ``outcome`` must be one of ``self.outputs`` (or + the implicit ``"unscored"`` sentinel from ``OutcomeScorer``). + """ + ... + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this step. + + Default implementation captures the step name and declared outputs. + Subclasses should override to include their inputs (e.g., attack + technique identifiers) and any other behavioral params. + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + return ComponentIdentifier.of( + self, + params={ + "name": self.name, + "outputs": list(self.outputs), + }, + ) diff --git a/pyrit/scenario/core/strategy_graph.py b/pyrit/scenario/core/strategy_graph.py new file mode 100644 index 000000000..da178d2da --- /dev/null +++ b/pyrit/scenario/core/strategy_graph.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``StrategyGraph`` — the orchestrator side of the scenario state machine. + +A ``StrategyGraph`` owns a ``policy: dict[state, action]`` mapping and walks +through states by invoking the action bound to the current state. Each action +returns a ``(next_state, ScenarioStepResult | None)`` pair; the graph yields +the result and advances. Terminal states stop the loop. + +This module is part of the scenario core refactor scaffold (Phase 0). The +skeleton supports straight-line execution; richer policy composition (graph +validation, cycle detection, parallel branches) lands as Phase 3 needs it. +""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import TYPE_CHECKING, Generic, TypeVar + +from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike + +if TYPE_CHECKING: + from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult + +logger = logging.getLogger(__name__) + +StepT = TypeVar("StepT", bound="ScenarioStep") +StateT = TypeVar("StateT", bound=ScenarioStateLike) + +#: A policy action receives the graph and returns the next state plus the +#: result produced (or ``None`` if the action did no observable work). +PolicyAction = Callable[ + ["StrategyGraph[StepT, StateT]"], + Awaitable[tuple[StateT, "ScenarioStepResult | None"]], +] + + +class StrategyGraph(Generic[StepT, StateT]): + """ + Policy-driven state machine over ``ScenarioStep``s. + + Construct with a ``policy`` dict mapping each non-terminal state to an + async callable that returns the next state and (optionally) the step + result produced. ``event_loop_async`` iterates the graph: at each state + it invokes the bound action, yields any returned result, and advances to + the next state until a terminal state is reached. + + The graph maintains ``current_state``, ``current_step``, and ``history`` + so that retries can resume from the last persisted state without + replaying completed work. + """ + + def __init__( + self, + *, + policy: dict[StateT, PolicyAction[StepT, StateT]], + initial_state: StateT, + terminal_states: set[StateT], + ) -> None: + """ + Initialize a ``StrategyGraph``. + + Args: + policy (dict[StateT, PolicyAction[StepT, StateT]]): Mapping from + each non-terminal state to the async action that fires while + in that state. + initial_state (StateT): Starting state. Must not be in + ``terminal_states`` (a graph that starts in a terminal state + does no work and is almost certainly a bug). + terminal_states (set[StateT]): States that stop ``event_loop_async`` + when reached. Must be non-empty. + + Raises: + ValueError: If ``terminal_states`` is empty, ``initial_state`` + is terminal, or a non-terminal state lacks a policy entry. + """ + if not terminal_states: + raise ValueError("StrategyGraph requires at least one terminal state.") + if initial_state in terminal_states: + raise ValueError( + f"initial_state {initial_state!r} is in terminal_states; the graph would do no work." + ) + + missing_policy = [state for state in policy if state in terminal_states] + if missing_policy: + raise ValueError( + f"Terminal states must not appear in policy: {missing_policy!r}." + ) + + self._policy = dict(policy) + self._initial_state = initial_state + self._terminal_states = set(terminal_states) + self._current_state: StateT = initial_state + self._current_step: StepT | None = None + self._history: list[tuple[StateT, ScenarioStepResult]] = [] + + @property + def current_state(self) -> StateT: + """Return the graph's current state.""" + return self._current_state + + @property + def current_step(self) -> StepT | None: + """Return the step bound to the current state, if the action set one.""" + return self._current_step + + @property + def history(self) -> list[tuple[StateT, ScenarioStepResult]]: + """Return the ordered history of ``(state, result)`` pairs produced so far.""" + return list(self._history) + + @property + def is_terminal(self) -> bool: + """Return ``True`` if the graph is in a terminal state.""" + return self._current_state in self._terminal_states + + def bind_current_step(self, step: StepT | None) -> None: + """ + Set the step bound to the current state. + + Policy actions call this so external observers (e.g., the surrounding + ``Scenario``) can read ``graph.current_step`` while the action runs. + + Args: + step (StepT | None): The step the action is about to execute, or + ``None`` to clear. + """ + self._current_step = step + + def reset(self) -> None: + """ + Reset the graph back to ``initial_state`` and clear history. + + Used by retry paths that want a clean slate rather than resuming + from the last persisted state. + """ + self._current_state = self._initial_state + self._current_step = None + self._history = [] + + async def event_loop_async(self) -> AsyncIterator[ScenarioStepResult]: + """ + Walk the policy graph, yielding each non-null ``ScenarioStepResult``. + + Restartable: callers may resume from the current state after an + exception or external interruption. Each iteration: + + 1. Looks up the action for ``current_state``. + 2. Awaits the action to receive ``(next_state, result)``. + 3. Appends ``(state_before, result)`` to history when ``result`` is + non-null. + 4. Advances ``current_state`` to ``next_state``. + 5. Stops when ``current_state`` becomes terminal. + + Yields: + ScenarioStepResult: Each non-null result produced by a policy + action, in execution order. + + Raises: + KeyError: If the graph reaches a non-terminal state with no + policy entry (indicates malformed policy). + """ + while not self.is_terminal: + state_before = self._current_state + action = self._policy.get(state_before) + if action is None: + raise KeyError( + f"StrategyGraph reached non-terminal state {state_before!r} " + f"with no policy entry." + ) + + next_state, result = await action(self) + if result is not None: + self._history.append((state_before, result)) + yield result + + self._current_state = next_state + + +__all__ = [ + "StrategyGraph", + "PolicyAction", + "ScenarioCoreState", +] diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index dfdafdda4..e90252d71 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -13,6 +13,7 @@ from pyrit.output.scorer.pretty import PrettyScorerMemoryPrinter as ConsoleScorerPrinter from pyrit.score.batch_scorer import BatchScorer from pyrit.score.conversation_scorer import ConversationScorer, create_conversation_scorer +from pyrit.score.decorators import OutcomeScorer from pyrit.score.float_scale.azure_content_filter_scorer import AzureContentFilterScorer from pyrit.score.float_scale.float_scale_score_aggregator import ( FloatScaleScoreAggregator, @@ -135,6 +136,7 @@ def __getattr__(name: str) -> object: "ObjectiveHumanLabeledEntry", "ObjectiveScorerEvaluator", "ObjectiveScorerMetrics", + "OutcomeScorer", "PlagiarismMetric", "PlagiarismScorer", "PromptShieldScorer", diff --git a/pyrit/score/decorators/__init__.py b/pyrit/score/decorators/__init__.py new file mode 100644 index 000000000..a5c368327 --- /dev/null +++ b/pyrit/score/decorators/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Scorer decorators that wrap a ``Scorer`` to add behavior without subclassing.""" + +from pyrit.score.decorators.outcome_scorer import OutcomeScorer + +__all__ = ["OutcomeScorer"] diff --git a/pyrit/score/decorators/outcome_scorer.py b/pyrit/score/decorators/outcome_scorer.py new file mode 100644 index 000000000..dd2d40b80 --- /dev/null +++ b/pyrit/score/decorators/outcome_scorer.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``OutcomeScorer`` — wrap a ``Scorer`` to resolve its output into a transition label. + +This is the bridge between scorer outputs and state-machine transitions for +the scenario core refactor. It is a composition wrapper, not a ``Scorer`` +subclass: callers that need the full ``Scorer`` interface use +``OutcomeScorer.wrapped_scorer`` directly. + +Part of the scenario core refactor scaffold (Phase 1). ``ScenarioStep`` +implementations consume ``OutcomeScorer.resolve_outcome_async`` to map a +freshly produced response into one of the step's declared ``outputs``. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar + +if TYPE_CHECKING: + from pyrit.models import Message, Score + from pyrit.score import Scorer + + +class OutcomeScorer: + """ + Decorator that maps a wrapped ``Scorer``'s output to a transition label. + + The ``outcome_map`` is an ordered ``dict`` of ``label -> predicate``. The + first label whose predicate returns ``True`` for any produced ``Score`` is + emitted. If no predicate matches (or the wrapped scorer returns no + scores), the sentinel ``"unscored"`` is emitted; the surrounding policy + can declare an explicit ``"unscored" -> `` transition if needed. + """ + + #: Sentinel label emitted when no entry in ``outcome_map`` matches and + #: when the wrapped scorer returns no scores at all. + UNSCORED: ClassVar[str] = "unscored" + + def __init__( + self, + *, + wrapped_scorer: Scorer, + outcome_map: dict[str, Callable[[Score], bool]], + ) -> None: + """ + Initialize an ``OutcomeScorer``. + + Args: + wrapped_scorer (Scorer): The scorer whose output drives outcome + resolution. The wrapper does not modify the scorer; it only + calls ``score_async`` and maps the result. + outcome_map (dict[str, Callable[[Score], bool]]): Ordered mapping + from transition label to a predicate over ``Score``. The first + label whose predicate matches any produced score is emitted. + Reserved label ``"unscored"`` must not appear here. + + Raises: + ValueError: If ``outcome_map`` is empty or contains the reserved + ``"unscored"`` label. + """ + if not outcome_map: + raise ValueError("OutcomeScorer requires a non-empty outcome_map.") + if self.UNSCORED in outcome_map: + raise ValueError( + f"Label {self.UNSCORED!r} is reserved as the no-match sentinel " + f"and may not appear in outcome_map." + ) + + self._wrapped_scorer = wrapped_scorer + self._outcome_map = dict(outcome_map) + + @property + def wrapped_scorer(self) -> Scorer: + """Return the wrapped scorer for callers needing the full Scorer interface.""" + return self._wrapped_scorer + + @property + def outcomes(self) -> list[str]: + """ + Return the full list of labels this scorer can emit. + + Includes the implicit ``"unscored"`` sentinel at the end, so step + validation (``set(outcomes).issubset(step.outputs)``) catches missing + ``"unscored"`` declarations early. + + Returns: + list[str]: Ordered labels, with ``"unscored"`` last. + """ + return [*self._outcome_map.keys(), self.UNSCORED] + + async def resolve_outcome_async( + self, + message: Message, + *, + objective: str | None = None, + ) -> str: + """ + Score ``message`` with the wrapped scorer and return the first matching label. + + Args: + message (Message): The message to score. + objective (str | None): Optional objective forwarded to the + wrapped scorer's ``score_async``. + + Returns: + str: The first matching label from ``outcome_map``, or + ``"unscored"`` if no entry matches or the wrapped scorer + returned no scores. + """ + scores = await self._wrapped_scorer.score_async(message, objective=objective) + if not scores: + return self.UNSCORED + + for label, predicate in self._outcome_map.items(): + if any(predicate(score) for score in scores): + return label + + return self.UNSCORED diff --git a/tests/unit/identifiers/test_step_identifier.py b/tests/unit/identifiers/test_step_identifier.py new file mode 100644 index 000000000..9fac1e67e --- /dev/null +++ b/tests/unit/identifiers/test_step_identifier.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.identifiers.step_identifier``.""" + +import pytest + +from pyrit.identifiers import ( + STEP_EVAL_VERSION, + ComponentIdentifier, + build_step_identifier, +) + + +def _make_atomic_identifier(*, hash_suffix: str = "a") -> ComponentIdentifier: + return ComponentIdentifier( + class_name="AtomicAttack", + class_module="pyrit.scenario.core.atomic_attack", + params={"marker": hash_suffix}, + ) + + +def test_build_step_identifier_returns_component_identifier(): + result = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[], + ) + assert isinstance(result, ComponentIdentifier) + + +def test_class_name_is_scenario_step(): + result = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[], + ) + assert result.class_name == "ScenarioStep" + assert result.class_module == "pyrit.scenario.core.scenario_step" + + +def test_params_include_step_name_outcome_and_version(): + result = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[], + ) + assert result.params["step_name"] == "opening" + assert result.params["outcome"] == "violation" + assert result.params["eval_version"] == STEP_EVAL_VERSION + + +def test_attack_executions_are_nested_under_children(): + atomic_a = _make_atomic_identifier(hash_suffix="a") + atomic_b = _make_atomic_identifier(hash_suffix="b") + + result = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic_a, atomic_b], + ) + + assert "attack_executions" in result.children + nested = result.children["attack_executions"] + assert isinstance(nested, list) + assert nested[0].params["marker"] == "a" + assert nested[1].params["marker"] == "b" + + +def test_empty_attack_executions_allowed(): + """Steps that record outcomes from external signals may produce no attacks.""" + result = build_step_identifier( + step_name="external_signal", + outcome="received", + attack_execution_identifiers=[], + ) + assert result.children["attack_executions"] == [] + + +def test_rejects_empty_step_name(): + with pytest.raises(ValueError, match="step_name must be non-empty"): + build_step_identifier( + step_name="", + outcome="violation", + attack_execution_identifiers=[], + ) + + +def test_rejects_empty_outcome(): + with pytest.raises(ValueError, match="outcome must be non-empty"): + build_step_identifier( + step_name="opening", + outcome="", + attack_execution_identifiers=[], + ) + + +def test_hash_is_deterministic_for_same_inputs(): + atomic = _make_atomic_identifier() + a = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic], + ) + b = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic], + ) + assert a.hash == b.hash + + +def test_hash_differs_when_outcome_differs(): + atomic = _make_atomic_identifier() + violation = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic], + ) + refusal = build_step_identifier( + step_name="opening", + outcome="refusal", + attack_execution_identifiers=[atomic], + ) + assert violation.hash != refusal.hash + + +def test_hash_differs_when_step_name_differs(): + atomic = _make_atomic_identifier() + opening = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic], + ) + escalating = build_step_identifier( + step_name="escalating", + outcome="violation", + attack_execution_identifiers=[atomic], + ) + assert opening.hash != escalating.hash + + +def test_step_eval_version_is_positive_int(): + assert isinstance(STEP_EVAL_VERSION, int) + assert STEP_EVAL_VERSION >= 1 diff --git a/tests/unit/scenario/test_scenario_state.py b/tests/unit/scenario/test_scenario_state.py new file mode 100644 index 000000000..5b3e69162 --- /dev/null +++ b/tests/unit/scenario/test_scenario_state.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.scenario.core.scenario_state``.""" + +from enum import Enum + +from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike + + +def test_scenario_core_state_has_required_members(): + names = {member.name for member in ScenarioCoreState} + assert names == {"UNINITIALIZED", "INITIALIZING", "EXECUTING", "COMPLETE", "FAILED"} + + +def test_scenario_core_state_satisfies_protocol(): + # Runtime-checkable protocol verifies presence of `name` and `value`. + assert isinstance(ScenarioCoreState.UNINITIALIZED, ScenarioStateLike) + + +def test_per_scenario_subclass_enum_also_satisfies_protocol(): + class MyScenarioState(Enum): + OPENING_PHASE = "opening_phase" + ESCALATING = "escalating" + + assert isinstance(MyScenarioState.OPENING_PHASE, ScenarioStateLike) + + +def test_values_are_lower_snake_case_strings(): + for member in ScenarioCoreState: + assert isinstance(member.value, str) + assert member.value == member.name.lower() diff --git a/tests/unit/scenario/test_scenario_step.py b/tests/unit/scenario/test_scenario_step.py new file mode 100644 index 000000000..7a86a7009 --- /dev/null +++ b/tests/unit/scenario/test_scenario_step.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.scenario.core.scenario_step``.""" + +import pytest + +from pyrit.identifiers import ComponentIdentifier +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult + + +class _ConcreteStep(ScenarioStep): + """Minimal concrete step that returns a fixed outcome.""" + + def __init__(self, *, name: str = "test_step", outputs: list[str] | None = None): + self.name = name + self.outputs = outputs or ["done"] + + async def process_async(self) -> ScenarioStepResult: + return ScenarioStepResult(outcome="done") + + +def test_scenario_step_is_abstract(): + with pytest.raises(TypeError): + # Cannot instantiate the ABC directly because ``process_async`` is abstract. + ScenarioStep() # type: ignore[abstract] + + +def test_concrete_step_can_be_instantiated(): + step = _ConcreteStep() + assert step.name == "test_step" + assert step.outputs == ["done"] + + +async def test_process_async_returns_step_result(): + step = _ConcreteStep() + result = await step.process_async() + assert isinstance(result, ScenarioStepResult) + assert result.outcome == "done" + assert result.attack_results == [] + assert result.step_identifier is None + + +def test_get_identifier_includes_name_and_outputs(): + step = _ConcreteStep(name="opening", outputs=["pass", "fail"]) + identifier = step.get_identifier() + assert isinstance(identifier, ComponentIdentifier) + assert identifier.params["name"] == "opening" + assert identifier.params["outputs"] == ["pass", "fail"] + assert identifier.class_name == "_ConcreteStep" + + +def test_get_identifier_is_cached(): + step = _ConcreteStep() + first = step.get_identifier() + second = step.get_identifier() + assert first is second + + +def test_step_result_is_frozen(): + result = ScenarioStepResult(outcome="done") + with pytest.raises(Exception): # frozen dataclass raises FrozenInstanceError + result.outcome = "other" # type: ignore[misc] + + +def test_step_result_defaults(): + result = ScenarioStepResult(outcome="done") + assert result.attack_results == [] + assert result.step_identifier is None diff --git a/tests/unit/scenario/test_strategy_graph.py b/tests/unit/scenario/test_strategy_graph.py new file mode 100644 index 000000000..8021a9acf --- /dev/null +++ b/tests/unit/scenario/test_strategy_graph.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.scenario.core.strategy_graph``.""" + +from enum import Enum + +import pytest + +from pyrit.scenario.core.scenario_state import ScenarioCoreState +from pyrit.scenario.core.scenario_step import ScenarioStepResult +from pyrit.scenario.core.strategy_graph import StrategyGraph + + +class _State(Enum): + START = "start" + MIDDLE = "middle" + END = "end" + + +def test_init_requires_non_empty_terminal_states(): + with pytest.raises(ValueError, match="at least one terminal state"): + StrategyGraph( + policy={}, + initial_state=_State.START, + terminal_states=set(), + ) + + +def test_init_rejects_initial_state_in_terminals(): + with pytest.raises(ValueError, match="would do no work"): + StrategyGraph( + policy={}, + initial_state=_State.END, + terminal_states={_State.END}, + ) + + +def test_init_rejects_terminal_state_with_policy_entry(): + async def _action(graph): + return _State.END, None + + with pytest.raises(ValueError, match="Terminal states must not appear in policy"): + StrategyGraph( + policy={_State.END: _action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + +async def test_event_loop_yields_results_in_order(): + async def start_action(graph): + return _State.MIDDLE, ScenarioStepResult(outcome="from_start") + + async def middle_action(graph): + return _State.END, ScenarioStepResult(outcome="from_middle") + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: start_action, _State.MIDDLE: middle_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + results = [r async for r in graph.event_loop_async()] + assert [r.outcome for r in results] == ["from_start", "from_middle"] + + +async def test_event_loop_advances_current_state(): + async def start_action(graph): + return _State.END, ScenarioStepResult(outcome="done") + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: start_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + assert graph.current_state == _State.START + _ = [r async for r in graph.event_loop_async()] + assert graph.current_state == _State.END + assert graph.is_terminal + + +async def test_event_loop_skips_yield_when_action_returns_no_result(): + async def silent_action(graph): + return _State.END, None + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: silent_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + results = [r async for r in graph.event_loop_async()] + assert results == [] + assert graph.current_state == _State.END + + +async def test_event_loop_raises_on_missing_policy_entry(): + async def start_action(graph): + # Transition to MIDDLE but the graph has no policy entry for MIDDLE. + return _State.MIDDLE, None + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: start_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + with pytest.raises(KeyError, match="no policy entry"): + _ = [r async for r in graph.event_loop_async()] + + +async def test_history_records_yielded_results_only(): + async def start_action(graph): + return _State.MIDDLE, ScenarioStepResult(outcome="recorded") + + async def middle_action(graph): + return _State.END, None + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: start_action, _State.MIDDLE: middle_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + _ = [r async for r in graph.event_loop_async()] + assert len(graph.history) == 1 + state_before, result = graph.history[0] + assert state_before == _State.START + assert result.outcome == "recorded" + + +async def test_reset_returns_graph_to_initial_state(): + async def start_action(graph): + return _State.END, ScenarioStepResult(outcome="done") + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: start_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + _ = [r async for r in graph.event_loop_async()] + assert graph.current_state == _State.END + assert len(graph.history) == 1 + + graph.reset() + assert graph.current_state == _State.START + assert graph.history == [] + assert graph.current_step is None + + +def test_bind_current_step_sets_and_clears(): + async def noop_action(graph): + return _State.END, None + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: noop_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + graph.bind_current_step(None) + assert graph.current_step is None + + sentinel = object() + graph.bind_current_step(sentinel) # type: ignore[arg-type] + assert graph.current_step is sentinel + + +async def test_event_loop_is_restartable_from_current_state(): + """After an exception, the loop can be re-entered from the failed state.""" + call_count = {"n": 0} + + async def flaky_action(graph): + call_count["n"] += 1 + if call_count["n"] == 1: + raise RuntimeError("first call fails") + return _State.END, ScenarioStepResult(outcome="recovered") + + graph: StrategyGraph = StrategyGraph( + policy={_State.START: flaky_action}, + initial_state=_State.START, + terminal_states={_State.END}, + ) + + # First attempt blows up; graph stays at START. + with pytest.raises(RuntimeError, match="first call fails"): + _ = [r async for r in graph.event_loop_async()] + assert graph.current_state == _State.START + + # Retry: the loop restarts from current_state and now succeeds. + results = [r async for r in graph.event_loop_async()] + assert [r.outcome for r in results] == ["recovered"] + assert graph.current_state == _State.END + + +async def test_strategy_graph_works_with_scenario_core_state(): + async def initializing_action(graph): + return ScenarioCoreState.EXECUTING, None + + async def executing_action(graph): + return ScenarioCoreState.COMPLETE, ScenarioStepResult(outcome="finished") + + graph: StrategyGraph = StrategyGraph( + policy={ + ScenarioCoreState.INITIALIZING: initializing_action, + ScenarioCoreState.EXECUTING: executing_action, + }, + initial_state=ScenarioCoreState.INITIALIZING, + terminal_states={ScenarioCoreState.COMPLETE, ScenarioCoreState.FAILED}, + ) + + results = [r async for r in graph.event_loop_async()] + assert [r.outcome for r in results] == ["finished"] + assert graph.current_state == ScenarioCoreState.COMPLETE diff --git a/tests/unit/score/decorators/test_outcome_scorer.py b/tests/unit/score/decorators/test_outcome_scorer.py new file mode 100644 index 000000000..46dc78a28 --- /dev/null +++ b/tests/unit/score/decorators/test_outcome_scorer.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``pyrit.score.decorators.outcome_scorer``.""" + +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import Message, MessagePiece, Score +from pyrit.score import Scorer +from pyrit.score.decorators import OutcomeScorer + + +def _make_score(*, value: str, score_type: str = "true_false") -> Score: + return Score( + score_value=value, + score_value_description="", + score_type=score_type, # type: ignore[arg-type] + score_rationale="", + message_piece_id=str(uuid4()), + scorer_class_identifier=ComponentIdentifier( + class_name="MockScorer", + class_module="tests.unit.score.decorators.test_outcome_scorer", + ), + ) + + +def _make_message(*, text: str = "hello") -> Message: + return Message(message_pieces=[MessagePiece(role="user", original_value=text)]) + + +def test_init_rejects_empty_outcome_map(): + scorer = MagicMock(spec=Scorer) + with pytest.raises(ValueError, match="non-empty outcome_map"): + OutcomeScorer(wrapped_scorer=scorer, outcome_map={}) + + +def test_init_rejects_reserved_unscored_label(): + scorer = MagicMock(spec=Scorer) + with pytest.raises(ValueError, match="reserved"): + OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"unscored": lambda s: True}, + ) + + +def test_wrapped_scorer_is_exposed(): + inner = MagicMock(spec=Scorer) + outer = OutcomeScorer( + wrapped_scorer=inner, + outcome_map={"hit": lambda s: True}, + ) + assert outer.wrapped_scorer is inner + + +def test_outcomes_includes_unscored_last(): + scorer = MagicMock(spec=Scorer) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + "violation": lambda s: s.score_value == "true", + "refusal": lambda s: s.score_value == "false", + }, + ) + assert outer.outcomes == ["violation", "refusal", "unscored"] + + +async def test_resolve_outcome_returns_first_matching_label(): + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="true")]) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + "violation": lambda s: s.score_value == "true", + "refusal": lambda s: s.score_value == "false", + }, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label == "violation" + + +async def test_resolve_outcome_preserves_outcome_map_order(): + """When multiple predicates match, the first label wins.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="true")]) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + "first": lambda s: True, + "second": lambda s: True, + }, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label == "first" + + +async def test_resolve_outcome_returns_unscored_when_no_predicate_matches(): + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="false")]) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + "violation": lambda s: s.score_value == "true", + }, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label == "unscored" + + +async def test_resolve_outcome_returns_unscored_when_no_scores_produced(): + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[]) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"hit": lambda s: True}, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label == "unscored" + + +async def test_resolve_outcome_forwards_objective_to_wrapped_scorer(): + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="true")]) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"hit": lambda s: True}, + ) + + message = _make_message() + await outer.resolve_outcome_async(message, objective="break the model") + scorer.score_async.assert_called_once() + assert scorer.score_async.call_args.kwargs["objective"] == "break the model" + + +async def test_resolve_outcome_matches_against_any_score_in_list(): + """Predicate firing on *any* score in the list is enough.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock( + return_value=[ + _make_score(value="false"), + _make_score(value="false"), + _make_score(value="true"), + ] + ) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"hit": lambda s: s.score_value == "true"}, + ) + label = await outer.resolve_outcome_async(_make_message()) + assert label == "hit" From 9201e4a20b3a285453b51a1038054cbe8cde432e Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 12:30:49 -0700 Subject: [PATCH 02/42] feat(scenario): land ScenarioStep adapter for AtomicAttack (Phase 2) Aligns the Phase 0 scaffold with the codebase's policy patterns used by TargetCapabilities / TargetRequirements / ScorerOverridePolicy: - ScenarioCoreState now inherits (str, Enum) like CapabilityName and ScorerOverridePolicy, keeping state values JSON-serializable for resume payloads. - New frozen StrategyPolicy dataclass wraps actions / initial_state / terminal_states with MappingProxyType defensive copy and a keyword-only get_action(*, state=...) / is_terminal(*, state=...) lookup API, mirroring CapabilityHandlingPolicy.behaviors / get_behavior. - StrategyGraph is reduced to a thin orchestrator that consumes a single StrategyPolicy. Construction-time validation moved onto StrategyPolicy.__post_init__ so the policy is its own typed invariant. - bind_current_step(*, step=...) is now keyword-only. AtomicAttack inherits from ScenarioStep: - name property aliases atomic_attack_name (the resume / dedup key). - outputs returns a defensive copy of the single hard-coded `done` transition label. - process_async wraps run_async into a ScenarioStepResult; incomplete_objectives and input_indices ride in result.metadata so the orchestrator (Phase 5) can consume them without forcing every step to invent its own payload type. - _build_identifier nests the underlying AttackTechnique identifier under children. ScenarioStepResult gains a metadata: dict[str, Any] field so steps can carry per-step bookkeeping (incomplete objectives, adaptive selector state, etc.) without polluting the outcome label. Tests: 13 new ScenarioStep-contract tests for AtomicAttack and a full rewrite of test_strategy_graph.py to construct via StrategyPolicy. Scoped suite (tests/unit/scenario tests/unit/identifiers tests/unit/score) green: 1825 passed, 15 skipped. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/__init__.py | 14 ++ pyrit/scenario/core/__init__.py | 10 + pyrit/scenario/core/atomic_attack.py | 78 ++++++- pyrit/scenario/core/scenario_state.py | 7 +- pyrit/scenario/core/scenario_step.py | 10 +- pyrit/scenario/core/strategy_graph.py | 167 +++++++++----- pyrit/score/decorators/outcome_scorer.py | 3 +- .../unit/identifiers/test_step_identifier.py | 2 + .../test_atomic_attack_scenario_step.py | 215 ++++++++++++++++++ tests/unit/scenario/test_scenario_step.py | 2 +- tests/unit/scenario/test_strategy_graph.py | 205 ++++++++++++----- .../score/decorators/test_outcome_scorer.py | 2 +- 12 files changed, 592 insertions(+), 123 deletions(-) create mode 100644 tests/unit/scenario/test_atomic_attack_scenario_step.py diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index b66539543..680e8423a 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -23,9 +23,16 @@ AttackTechniqueFactory, BaselinePolicy, DatasetConfiguration, + PolicyAction, Scenario, ScenarioCompositeStrategy, + ScenarioCoreState, + ScenarioStateLike, + ScenarioStep, + ScenarioStepResult, ScenarioStrategy, + StrategyGraph, + StrategyPolicy, ) # Import scenario submodules directly and register them as virtual subpackages @@ -54,11 +61,18 @@ "BaselinePolicy", "DatasetConfiguration", "Parameter", + "PolicyAction", "Scenario", "ScenarioCompositeStrategy", + "ScenarioCoreState", + "ScenarioStateLike", + "ScenarioStep", + "ScenarioStepResult", "ScenarioStrategy", "ScenarioIdentifier", "ScenarioResult", + "StrategyGraph", + "StrategyPolicy", "airt", "benchmark", "garak", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 89c8935da..8fb32f303 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -9,12 +9,15 @@ from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration from pyrit.scenario.core.scenario import BaselinePolicy, Scenario +from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target, get_default_scorer_target from pyrit.scenario.core.scenario_techniques import ( SCENARIO_TECHNIQUES, register_scenario_techniques, ) +from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy __all__ = [ "AtomicAttack", @@ -23,12 +26,19 @@ "BaselinePolicy", "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", + "PolicyAction", "SCENARIO_TECHNIQUES", "Parameter", "Scenario", "ScenarioCompositeStrategy", + "ScenarioCoreState", + "ScenarioStateLike", + "ScenarioStep", + "ScenarioStepResult", "ScenarioStrategy", "ScorerOverridePolicy", + "StrategyGraph", + "StrategyPolicy", "register_scenario_techniques", "get_default_scorer_target", "get_default_adversarial_target", diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index ef61b8b0b..be4b7eefa 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -19,12 +19,13 @@ from pyrit.common.deprecation import print_deprecation_message from pyrit.executor.attack import AttackExecutor, AttackStrategy from pyrit.executor.attack.core.attack_executor import AttackExecutorResult -from pyrit.identifiers import build_atomic_attack_identifier +from pyrit.identifiers import ComponentIdentifier, build_atomic_attack_identifier from pyrit.identifiers.evaluation_identifier import AtomicAttackEvaluationIdentifier from pyrit.memory import CentralMemory from pyrit.memory.memory_models import MAX_IDENTIFIER_VALUE_LENGTH from pyrit.models import AttackResult, SeedAttackGroup from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult if TYPE_CHECKING: from pyrit.prompt_target import PromptTarget @@ -33,7 +34,7 @@ logger = logging.getLogger(__name__) -class AtomicAttack: +class AtomicAttack(ScenarioStep): """ Represents a single atomic attack test combining an attack strategy and dataset. @@ -47,8 +48,19 @@ class AtomicAttack: An ``AttackTechnique`` bundles the attack strategy with an optional ``SeedAttackTechniqueGroup``, cleanly separating "how to attack" from "what to attack" (the objective). + + Implements the ``ScenarioStep`` contract so the upcoming ``StrategyGraph`` + orchestrator can drive AtomicAttack the same way it drives richer steps: + one ``process_async`` call yields one ``ScenarioStepResult``. AtomicAttack + always emits the single ``"done"`` outcome — graph-aware scenarios that + want richer transitions implement their own ``ScenarioStep`` subclass. """ + #: Hard-coded outcome label for graph-based scenarios. AtomicAttack does + #: no outcome scoring of its own; richer steps override ``outputs`` to + #: declare scorer-driven transition labels. + _OUTPUTS: tuple[str, ...] = ("done",) + def __init__( self, *, @@ -152,6 +164,68 @@ def seed_groups(self) -> list[SeedAttackGroup]: """ return list(self._seed_groups) + @property + def name(self) -> str: + """ + Display / resume key for this atomic attack, satisfying ``ScenarioStep``. + + Aliases ``atomic_attack_name`` so existing code continues using the + original attribute while the ``StrategyGraph`` orchestrator reads + ``name`` uniformly across all step types. + """ + return self.atomic_attack_name + + @property + def outputs(self) -> list[str]: + """Transition labels this step can emit. AtomicAttack emits only ``"done"``.""" + return list(self._OUTPUTS) + + async def process_async(self) -> ScenarioStepResult: + """ + ``ScenarioStep`` adapter — runs the atomic attack and wraps the result. + + Delegates to ``run_async`` using the instance's stored execution + parameters, then packages the completed results into a + ``ScenarioStepResult``. Incomplete objectives and the executor's + ``input_indices`` are stashed in ``metadata`` so the orchestrator + (Phase 5) can drive resume / retry logic without losing information. + + Returns: + ScenarioStepResult: ``outcome="done"`` with the completed attack + results and execution bookkeeping in ``metadata``. + """ + executor_result = await self.run_async() + return ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this atomic attack. + + Captures the atomic attack name (the resume / dedup key) and nests + the underlying ``AttackTechnique`` identifier so hash drift in the + attack or its seeds propagates upward. + + Returns: + ComponentIdentifier: Identifier whose ``params`` carry the step + name and declared outputs, with the underlying attack identifier + nested under ``children``. + """ + return ComponentIdentifier.of( + self, + params={ + "atomic_attack_name": self.atomic_attack_name, + "outputs": list(self.outputs), + }, + children={"attack_technique": self._attack_technique.get_identifier()}, + ) + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: """ Filter seed groups to only those with objectives in the remaining list. diff --git a/pyrit/scenario/core/scenario_state.py b/pyrit/scenario/core/scenario_state.py index 9e18c3400..67ad5481a 100644 --- a/pyrit/scenario/core/scenario_state.py +++ b/pyrit/scenario/core/scenario_state.py @@ -37,10 +37,15 @@ class ScenarioStateLike(Protocol): value: object -class ScenarioCoreState(Enum): +class ScenarioCoreState(str, Enum): """ Lifecycle states shared by every scenario. + Inherits from ``str`` to match the codebase convention for canonical + identifier enums (``CapabilityName``, ``ScorerOverridePolicy``), keeping + state values JSON-serializable for resume payloads without explicit + coercion. + Per-scenario state enums extend this vocabulary by declaring their own enum class with additional members. """ diff --git a/pyrit/scenario/core/scenario_step.py b/pyrit/scenario/core/scenario_step.py index 8137327a7..e7f13397f 100644 --- a/pyrit/scenario/core/scenario_step.py +++ b/pyrit/scenario/core/scenario_step.py @@ -19,7 +19,7 @@ from abc import abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pyrit.identifiers import ComponentIdentifier, Identifiable @@ -44,11 +44,19 @@ class ScenarioStepResult: for this step execution. ``None`` until Phase 4 lands the ``step_identifier`` persistence column; populated after that for scenarios that opt into graph-based execution. + metadata (dict[str, Any]): Free-form metadata produced by the step. + Used to carry per-step bookkeeping that isn't part of the outcome + label itself — e.g., ``incomplete_objectives`` from a partial + ``AttackExecutorResult``, or selector state for adaptive steps. + The orchestrator in Phase 5 reads from this dict to drive + resume / retry logic without forcing every step to invent its own + payload type. """ outcome: str attack_results: list[AttackResult] = field(default_factory=list) step_identifier: ComponentIdentifier | None = None + metadata: dict[str, Any] = field(default_factory=dict) class ScenarioStep(Identifiable): diff --git a/pyrit/scenario/core/strategy_graph.py b/pyrit/scenario/core/strategy_graph.py index da178d2da..ac446186d 100644 --- a/pyrit/scenario/core/strategy_graph.py +++ b/pyrit/scenario/core/strategy_graph.py @@ -17,18 +17,18 @@ from __future__ import annotations import logging -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable, Hashable, Mapping +from dataclasses import dataclass, field +from types import MappingProxyType from typing import TYPE_CHECKING, Generic, TypeVar -from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike - if TYPE_CHECKING: from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult logger = logging.getLogger(__name__) StepT = TypeVar("StepT", bound="ScenarioStep") -StateT = TypeVar("StateT", bound=ScenarioStateLike) +StateT = TypeVar("StateT", bound=Hashable) #: A policy action receives the graph and returns the next state plus the #: result produced (or ``None`` if the action did no observable work). @@ -38,15 +38,99 @@ ] +@dataclass(frozen=True) +class StrategyPolicy(Generic[StepT, StateT]): + """ + Frozen declaration of how a ``StrategyGraph`` traverses its states. + + Mirrors the codebase's other policy wrappers (``CapabilityHandlingPolicy``, + ``TargetRequirements``): a typed wrapper around a mapping plus the + structural invariants that mapping must satisfy. Bundling the action + map, ``initial_state``, and ``terminal_states`` into one frozen value + means the policy can be constructed once at scenario class definition + time and shared across runs without risk of mutation. + + The action mapping is defensively copied into a ``MappingProxyType`` + in ``__post_init__`` so the policy is genuinely read-only. + + Attributes: + actions (Mapping[StateT, PolicyAction[StepT, StateT]]): Per-state + async callable that produces the next state plus an optional + ``ScenarioStepResult``. + initial_state (StateT): The state the graph starts in. Must not be + in ``terminal_states`` (a graph that starts terminal does no + work and is almost certainly a bug). + terminal_states (frozenset[StateT]): States that stop the event + loop when reached. Must be non-empty. Terminal states must not + appear in ``actions``. + """ + + actions: Mapping[StateT, PolicyAction[StepT, StateT]] + initial_state: StateT + terminal_states: frozenset[StateT] = field(default_factory=frozenset) + + def __post_init__(self) -> None: + """ + Validate the policy structure and freeze the action map. + + Raises: + ValueError: If ``terminal_states`` is empty, ``initial_state`` + is itself a terminal state, or a terminal state appears as + a key in ``actions`` (which would never fire). + """ + if not self.terminal_states: + raise ValueError("StrategyPolicy requires at least one terminal state.") + if self.initial_state in self.terminal_states: + raise ValueError( + f"initial_state {self.initial_state!r} is in terminal_states; " + f"the graph would do no work." + ) + + overlap = [state for state in self.actions if state in self.terminal_states] + if overlap: + raise ValueError( + f"Terminal states must not appear in actions: {overlap!r}." + ) + + object.__setattr__(self, "actions", MappingProxyType(dict(self.actions))) + object.__setattr__(self, "terminal_states", frozenset(self.terminal_states)) + + def get_action(self, *, state: StateT) -> PolicyAction[StepT, StateT]: + """ + Return the action bound to ``state``. + + Args: + state (StateT): The state to look up. + + Returns: + PolicyAction[StepT, StateT]: The configured async action. + + Raises: + KeyError: If ``state`` is non-terminal and has no entry in + ``actions`` (indicates a malformed policy). + """ + try: + return self.actions[state] + except KeyError: + known = ", ".join(sorted(str(s) for s in self.actions)) + raise KeyError( + f"No action defined for state {state!r}. Known states: {known or '(none)'}." + ) from None + + def is_terminal(self, *, state: StateT) -> bool: + """Return ``True`` if ``state`` is a terminal state.""" + return state in self.terminal_states + + class StrategyGraph(Generic[StepT, StateT]): """ Policy-driven state machine over ``ScenarioStep``s. - Construct with a ``policy`` dict mapping each non-terminal state to an - async callable that returns the next state and (optionally) the step - result produced. ``event_loop_async`` iterates the graph: at each state - it invokes the bound action, yields any returned result, and advances to - the next state until a terminal state is reached. + Construct with a frozen ``StrategyPolicy`` describing the action map, + starting state, and terminal states. ``event_loop_async`` iterates the + graph: at each state it invokes the bound action, yields any returned + result, and advances to the next state until a terminal state is + reached. The graph maintains ``current_state``, ``current_step``, and ``history`` so that retries can resume from the last persisted state without @@ -56,47 +140,27 @@ class StrategyGraph(Generic[StepT, StateT]): def __init__( self, *, - policy: dict[StateT, PolicyAction[StepT, StateT]], - initial_state: StateT, - terminal_states: set[StateT], + policy: StrategyPolicy[StepT, StateT], ) -> None: """ - Initialize a ``StrategyGraph``. + Initialize a ``StrategyGraph`` from a frozen policy. Args: - policy (dict[StateT, PolicyAction[StepT, StateT]]): Mapping from - each non-terminal state to the async action that fires while - in that state. - initial_state (StateT): Starting state. Must not be in - ``terminal_states`` (a graph that starts in a terminal state - does no work and is almost certainly a bug). - terminal_states (set[StateT]): States that stop ``event_loop_async`` - when reached. Must be non-empty. - - Raises: - ValueError: If ``terminal_states`` is empty, ``initial_state`` - is terminal, or a non-terminal state lacks a policy entry. + policy (StrategyPolicy[StepT, StateT]): Frozen policy describing + the action map, initial state, and terminal states. All + structural validation lives on ``StrategyPolicy``; the graph + trusts the policy it receives. """ - if not terminal_states: - raise ValueError("StrategyGraph requires at least one terminal state.") - if initial_state in terminal_states: - raise ValueError( - f"initial_state {initial_state!r} is in terminal_states; the graph would do no work." - ) - - missing_policy = [state for state in policy if state in terminal_states] - if missing_policy: - raise ValueError( - f"Terminal states must not appear in policy: {missing_policy!r}." - ) - - self._policy = dict(policy) - self._initial_state = initial_state - self._terminal_states = set(terminal_states) - self._current_state: StateT = initial_state + self._policy = policy + self._current_state: StateT = policy.initial_state self._current_step: StepT | None = None self._history: list[tuple[StateT, ScenarioStepResult]] = [] + @property + def policy(self) -> StrategyPolicy[StepT, StateT]: + """Return the frozen policy that drives this graph.""" + return self._policy + @property def current_state(self) -> StateT: """Return the graph's current state.""" @@ -115,9 +179,9 @@ def history(self) -> list[tuple[StateT, ScenarioStepResult]]: @property def is_terminal(self) -> bool: """Return ``True`` if the graph is in a terminal state.""" - return self._current_state in self._terminal_states + return self._policy.is_terminal(state=self._current_state) - def bind_current_step(self, step: StepT | None) -> None: + def bind_current_step(self, *, step: StepT | None) -> None: """ Set the step bound to the current state. @@ -132,12 +196,12 @@ def bind_current_step(self, step: StepT | None) -> None: def reset(self) -> None: """ - Reset the graph back to ``initial_state`` and clear history. + Reset the graph back to ``policy.initial_state`` and clear history. Used by retry paths that want a clean slate rather than resuming from the last persisted state. """ - self._current_state = self._initial_state + self._current_state = self._policy.initial_state self._current_step = None self._history = [] @@ -148,7 +212,7 @@ async def event_loop_async(self) -> AsyncIterator[ScenarioStepResult]: Restartable: callers may resume from the current state after an exception or external interruption. Each iteration: - 1. Looks up the action for ``current_state``. + 1. Looks up the action for ``current_state`` via the policy. 2. Awaits the action to receive ``(next_state, result)``. 3. Appends ``(state_before, result)`` to history when ``result`` is non-null. @@ -165,12 +229,7 @@ async def event_loop_async(self) -> AsyncIterator[ScenarioStepResult]: """ while not self.is_terminal: state_before = self._current_state - action = self._policy.get(state_before) - if action is None: - raise KeyError( - f"StrategyGraph reached non-terminal state {state_before!r} " - f"with no policy entry." - ) + action = self._policy.get_action(state=state_before) next_state, result = await action(self) if result is not None: @@ -182,6 +241,6 @@ async def event_loop_async(self) -> AsyncIterator[ScenarioStepResult]: __all__ = [ "StrategyGraph", + "StrategyPolicy", "PolicyAction", - "ScenarioCoreState", ] diff --git a/pyrit/score/decorators/outcome_scorer.py b/pyrit/score/decorators/outcome_scorer.py index dd2d40b80..9bce0eedd 100644 --- a/pyrit/score/decorators/outcome_scorer.py +++ b/pyrit/score/decorators/outcome_scorer.py @@ -16,10 +16,11 @@ from __future__ import annotations -from collections.abc import Callable from typing import TYPE_CHECKING, ClassVar if TYPE_CHECKING: + from collections.abc import Callable + from pyrit.models import Message, Score from pyrit.score import Scorer diff --git a/tests/unit/identifiers/test_step_identifier.py b/tests/unit/identifiers/test_step_identifier.py index 9fac1e67e..13e8229cb 100644 --- a/tests/unit/identifiers/test_step_identifier.py +++ b/tests/unit/identifiers/test_step_identifier.py @@ -63,6 +63,8 @@ def test_attack_executions_are_nested_under_children(): assert "attack_executions" in result.children nested = result.children["attack_executions"] assert isinstance(nested, list) + assert isinstance(nested[0], ComponentIdentifier) + assert isinstance(nested[1], ComponentIdentifier) assert nested[0].params["marker"] == "a" assert nested[1].params["marker"] == "b" diff --git a/tests/unit/scenario/test_atomic_attack_scenario_step.py b/tests/unit/scenario/test_atomic_attack_scenario_step.py new file mode 100644 index 000000000..c3593a720 --- /dev/null +++ b/tests/unit/scenario/test_atomic_attack_scenario_step.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for ``AtomicAttack`` as a ``ScenarioStep`` (Phase 2 adapter). + +Verifies that ``AtomicAttack`` satisfies the ``ScenarioStep`` contract so +the upcoming ``StrategyGraph`` orchestrator can drive it uniformly with +richer step types. Existing ``AtomicAttack`` behavior is covered by +``test_atomic_attack.py``. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.executor.attack import AttackExecutor, AttackStrategy +from pyrit.executor.attack.core import AttackExecutorResult +from pyrit.identifiers import ComponentIdentifier, Identifiable +from pyrit.models import ( + AttackOutcome, + AttackResult, + SeedAttackGroup, + SeedObjective, + SeedPrompt, +) +from pyrit.scenario import AtomicAttack +from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult + + +@pytest.fixture +def mock_attack(): + return MagicMock(spec=AttackStrategy) + + +@pytest.fixture +def seed_groups(): + return [ + SeedAttackGroup(seeds=[SeedObjective(value="obj1"), SeedPrompt(value="p1")]), + SeedAttackGroup(seeds=[SeedObjective(value="obj2"), SeedPrompt(value="p2")]), + ] + + +@pytest.fixture +def attack_results(): + return [ + AttackResult( + conversation_id="conv-1", + objective="obj1", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ), + AttackResult( + conversation_id="conv-2", + objective="obj2", + outcome=AttackOutcome.FAILURE, + executed_turns=1, + ), + ] + + +def test_atomic_attack_is_scenario_step(): + assert issubclass(AtomicAttack, ScenarioStep) + + +def test_atomic_attack_is_identifiable(): + assert issubclass(AtomicAttack, Identifiable) + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackScenarioStepShape: + """Static shape: name/outputs/identifier behavior independent of execution.""" + + def test_name_aliases_atomic_attack_name(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + assert atomic.name == "my_step" + assert atomic.name == atomic.atomic_attack_name + + def test_outputs_is_done_only(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + assert atomic.outputs == ["done"] + + def test_outputs_property_returns_fresh_list(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + first = atomic.outputs + first.append("extra") + # Mutation of the returned list must not affect the next read. + assert atomic.outputs == ["done"] + + def test_get_identifier_returns_component_identifier(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + identifier = atomic.get_identifier() + assert isinstance(identifier, ComponentIdentifier) + assert identifier.params["atomic_attack_name"] == "my_step" + assert identifier.params["outputs"] == ["done"] + assert "attack_technique" in identifier.children + + def test_get_identifier_is_cached(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + first = atomic.get_identifier() + second = atomic.get_identifier() + assert first is second + + def test_identifier_hash_differs_when_name_differs(self, mock_attack, seed_groups): + atomic_a = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="step_a", + ) + atomic_b = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="step_b", + ) + assert atomic_a.get_identifier().hash != atomic_b.get_identifier().hash + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackProcessAsync: + """``process_async`` wraps ``run_async`` into a ``ScenarioStepResult``.""" + + async def test_returns_scenario_step_result_with_done_outcome( + self, mock_attack, seed_groups, attack_results + ): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.return_value = AttackExecutorResult( + completed_results=attack_results, + incomplete_objectives=[], + input_indices=[0, 1], + ) + result = await atomic.process_async() + + assert isinstance(result, ScenarioStepResult) + assert result.outcome == "done" + assert result.attack_results == attack_results + + async def test_metadata_carries_incomplete_objectives( + self, mock_attack, seed_groups, attack_results + ): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + + incomplete: list[tuple[str, BaseException]] = [("obj_failed", RuntimeError("boom"))] + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.return_value = AttackExecutorResult( + completed_results=attack_results[:1], + incomplete_objectives=incomplete, + input_indices=[0], + ) + result = await atomic.process_async() + + assert result.metadata["incomplete_objectives"] == incomplete + assert result.metadata["input_indices"] == [0] + + async def test_propagates_run_async_failures(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.side_effect = RuntimeError("execution exploded") + # run_async wraps the underlying exception in a ValueError; process_async surfaces it. + with pytest.raises(ValueError, match="Failed to execute atomic attack"): + await atomic.process_async() + + async def test_returns_empty_results_when_no_completions(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.return_value = AttackExecutorResult( + completed_results=[], + incomplete_objectives=[("obj1", RuntimeError("fail"))], + input_indices=[], + ) + result = await atomic.process_async() + + assert result.outcome == "done" + assert result.attack_results == [] + assert len(result.metadata["incomplete_objectives"]) == 1 diff --git a/tests/unit/scenario/test_scenario_step.py b/tests/unit/scenario/test_scenario_step.py index 7a86a7009..914882b64 100644 --- a/tests/unit/scenario/test_scenario_step.py +++ b/tests/unit/scenario/test_scenario_step.py @@ -60,7 +60,7 @@ def test_get_identifier_is_cached(): def test_step_result_is_frozen(): result = ScenarioStepResult(outcome="done") with pytest.raises(Exception): # frozen dataclass raises FrozenInstanceError - result.outcome = "other" # type: ignore[misc] + result.outcome = "other" # type: ignore[ty:invalid-assignment] def test_step_result_defaults(): diff --git a/tests/unit/scenario/test_strategy_graph.py b/tests/unit/scenario/test_strategy_graph.py index 8021a9acf..b74e213e8 100644 --- a/tests/unit/scenario/test_strategy_graph.py +++ b/tests/unit/scenario/test_strategy_graph.py @@ -9,7 +9,7 @@ from pyrit.scenario.core.scenario_state import ScenarioCoreState from pyrit.scenario.core.scenario_step import ScenarioStepResult -from pyrit.scenario.core.strategy_graph import StrategyGraph +from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy class _State(Enum): @@ -18,34 +18,97 @@ class _State(Enum): END = "end" -def test_init_requires_non_empty_terminal_states(): - with pytest.raises(ValueError, match="at least one terminal state"): - StrategyGraph( - policy={}, +def _make_graph(*, policy: StrategyPolicy) -> StrategyGraph: + return StrategyGraph(policy=policy) + + +# --------------------------------------------------------------------------- +# StrategyPolicy construction & validation +# --------------------------------------------------------------------------- + + +class TestStrategyPolicyInit: + + def test_requires_non_empty_terminal_states(self): + with pytest.raises(ValueError, match="at least one terminal state"): + StrategyPolicy( + actions={}, + initial_state=_State.START, + terminal_states=frozenset(), + ) + + def test_rejects_initial_state_in_terminals(self): + with pytest.raises(ValueError, match="would do no work"): + StrategyPolicy( + actions={}, + initial_state=_State.END, + terminal_states=frozenset({_State.END}), + ) + + def test_rejects_terminal_state_with_action_entry(self): + async def _action(graph): + return _State.END, None + + with pytest.raises(ValueError, match="Terminal states must not appear"): + StrategyPolicy( + actions={_State.END: _action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ) + + def test_actions_mapping_is_read_only(self): + async def _action(graph): + return _State.END, None + + policy = StrategyPolicy( + actions={_State.START: _action}, initial_state=_State.START, - terminal_states=set(), + terminal_states=frozenset({_State.END}), ) + with pytest.raises(TypeError): + policy.actions[_State.MIDDLE] = _action # type: ignore[ty:invalid-assignment] + + def test_get_action_returns_configured_action(self): + async def _action(graph): + return _State.END, None -def test_init_rejects_initial_state_in_terminals(): - with pytest.raises(ValueError, match="would do no work"): - StrategyGraph( - policy={}, - initial_state=_State.END, - terminal_states={_State.END}, + policy = StrategyPolicy( + actions={_State.START: _action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), ) + assert policy.get_action(state=_State.START) is _action + def test_get_action_raises_informative_error_for_unknown_state(self): + async def _action(graph): + return _State.END, None -def test_init_rejects_terminal_state_with_policy_entry(): - async def _action(graph): - return _State.END, None + policy = StrategyPolicy( + actions={_State.START: _action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ) + + with pytest.raises(KeyError, match="No action defined for state"): + policy.get_action(state=_State.MIDDLE) - with pytest.raises(ValueError, match="Terminal states must not appear in policy"): - StrategyGraph( - policy={_State.END: _action}, + def test_is_terminal_predicate(self): + async def _action(graph): + return _State.END, None + + policy = StrategyPolicy( + actions={_State.START: _action}, initial_state=_State.START, - terminal_states={_State.END}, + terminal_states=frozenset({_State.END}), ) + assert policy.is_terminal(state=_State.END) + assert not policy.is_terminal(state=_State.START) + + +# --------------------------------------------------------------------------- +# StrategyGraph traversal +# --------------------------------------------------------------------------- async def test_event_loop_yields_results_in_order(): @@ -55,10 +118,12 @@ async def start_action(graph): async def middle_action(graph): return _State.END, ScenarioStepResult(outcome="from_middle") - graph: StrategyGraph = StrategyGraph( - policy={_State.START: start_action, _State.MIDDLE: middle_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: start_action, _State.MIDDLE: middle_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) results = [r async for r in graph.event_loop_async()] @@ -69,10 +134,12 @@ async def test_event_loop_advances_current_state(): async def start_action(graph): return _State.END, ScenarioStepResult(outcome="done") - graph: StrategyGraph = StrategyGraph( - policy={_State.START: start_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: start_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) assert graph.current_state == _State.START @@ -85,10 +152,12 @@ async def test_event_loop_skips_yield_when_action_returns_no_result(): async def silent_action(graph): return _State.END, None - graph: StrategyGraph = StrategyGraph( - policy={_State.START: silent_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: silent_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) results = [r async for r in graph.event_loop_async()] @@ -96,18 +165,20 @@ async def silent_action(graph): assert graph.current_state == _State.END -async def test_event_loop_raises_on_missing_policy_entry(): +async def test_event_loop_raises_on_missing_action_entry(): async def start_action(graph): # Transition to MIDDLE but the graph has no policy entry for MIDDLE. return _State.MIDDLE, None - graph: StrategyGraph = StrategyGraph( - policy={_State.START: start_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: start_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) - with pytest.raises(KeyError, match="no policy entry"): + with pytest.raises(KeyError, match="No action defined for state"): _ = [r async for r in graph.event_loop_async()] @@ -118,10 +189,12 @@ async def start_action(graph): async def middle_action(graph): return _State.END, None - graph: StrategyGraph = StrategyGraph( - policy={_State.START: start_action, _State.MIDDLE: middle_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: start_action, _State.MIDDLE: middle_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) _ = [r async for r in graph.event_loop_async()] @@ -135,10 +208,12 @@ async def test_reset_returns_graph_to_initial_state(): async def start_action(graph): return _State.END, ScenarioStepResult(outcome="done") - graph: StrategyGraph = StrategyGraph( - policy={_State.START: start_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: start_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) _ = [r async for r in graph.event_loop_async()] @@ -155,17 +230,19 @@ def test_bind_current_step_sets_and_clears(): async def noop_action(graph): return _State.END, None - graph: StrategyGraph = StrategyGraph( - policy={_State.START: noop_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: noop_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) - graph.bind_current_step(None) + graph.bind_current_step(step=None) assert graph.current_step is None sentinel = object() - graph.bind_current_step(sentinel) # type: ignore[arg-type] + graph.bind_current_step(step=sentinel) # type: ignore[arg-type] assert graph.current_step is sentinel @@ -179,10 +256,12 @@ async def flaky_action(graph): raise RuntimeError("first call fails") return _State.END, ScenarioStepResult(outcome="recovered") - graph: StrategyGraph = StrategyGraph( - policy={_State.START: flaky_action}, - initial_state=_State.START, - terminal_states={_State.END}, + graph = _make_graph( + policy=StrategyPolicy( + actions={_State.START: flaky_action}, + initial_state=_State.START, + terminal_states=frozenset({_State.END}), + ), ) # First attempt blows up; graph stays at START. @@ -203,13 +282,15 @@ async def initializing_action(graph): async def executing_action(graph): return ScenarioCoreState.COMPLETE, ScenarioStepResult(outcome="finished") - graph: StrategyGraph = StrategyGraph( - policy={ - ScenarioCoreState.INITIALIZING: initializing_action, - ScenarioCoreState.EXECUTING: executing_action, - }, - initial_state=ScenarioCoreState.INITIALIZING, - terminal_states={ScenarioCoreState.COMPLETE, ScenarioCoreState.FAILED}, + graph = _make_graph( + policy=StrategyPolicy( + actions={ + ScenarioCoreState.INITIALIZING: initializing_action, + ScenarioCoreState.EXECUTING: executing_action, + }, + initial_state=ScenarioCoreState.INITIALIZING, + terminal_states=frozenset({ScenarioCoreState.COMPLETE, ScenarioCoreState.FAILED}), + ), ) results = [r async for r in graph.event_loop_async()] diff --git a/tests/unit/score/decorators/test_outcome_scorer.py b/tests/unit/score/decorators/test_outcome_scorer.py index 46dc78a28..a7e60ef94 100644 --- a/tests/unit/score/decorators/test_outcome_scorer.py +++ b/tests/unit/score/decorators/test_outcome_scorer.py @@ -18,7 +18,7 @@ def _make_score(*, value: str, score_type: str = "true_false") -> Score: return Score( score_value=value, score_value_description="", - score_type=score_type, # type: ignore[arg-type] + score_type=score_type, # type: ignore[ty:invalid-argument-type] score_rationale="", message_piece_id=str(uuid4()), scorer_class_identifier=ComponentIdentifier( From 56b7eba5ae3d14239948e7e74a19c884cd1e77d0 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 12:38:01 -0700 Subject: [PATCH 03/42] feat(scenario): add linear_strategy_policy builder + branching graph coverage (Phase 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Completes Phase 3 of the scenario-core refactor by adding the convenience policy builder and the branching-graph proof-of-concept tests called out in the rubber-duck pass. linear_strategy_policy(steps): - Produces a StrategyPolicy[ScenarioStep, int] that walks an ordered list of steps state-by-state, with action i binding steps[i] as current_step, awaiting its process_async, and transitioning to state i+1. State len(steps) is the sole terminal state. - Captures step / next_state via default-argument binding to dodge the classic late-binding closure bug in for-loops. - Always clears current_step in a finally so a step raising mid-execution doesn't leave the graph in an inconsistent state — the graph stays at the failed state so the existing retry loop can re-enter. - This is the policy Phase 5 will use to silently upgrade legacy scenarios that still declare their steps via _get_atomic_attacks_async. test_linear_strategy_policy.py (6 tests): - Locks the silent-upgrade contract: order preservation, binding lifecycle, late-binding bug guard, finally-clear on failure, and the empty-input guardrail. test_strategy_graph_branching.py (4 tests): - Forces the policy API through a non-trivial branching scenario (BroadSweepThenDeepDive) before Phase 5 commits to it: opening phase emits safe or violation; safe short-circuits to COMPLETE, violation routes through ESCALATION_PHASE first. - Confirms that history records both branch states, that escalation step metadata survives the round trip, and that graph.reset() correctly replays the branching path. Full unit suite: 7929 passed, 118 skipped (the one CLI test failure is the pre-existing ODBC driver missing on this host — unrelated to the refactor). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/__init__.py | 2 + pyrit/scenario/core/__init__.py | 3 +- pyrit/scenario/core/strategy_graph.py | 62 ++++++- .../scenario/test_linear_strategy_policy.py | 123 ++++++++++++++ .../scenario/test_strategy_graph_branching.py | 155 ++++++++++++++++++ 5 files changed, 340 insertions(+), 5 deletions(-) create mode 100644 tests/unit/scenario/test_linear_strategy_policy.py create mode 100644 tests/unit/scenario/test_strategy_graph_branching.py diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 680e8423a..c3e6956e2 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -33,6 +33,7 @@ ScenarioStrategy, StrategyGraph, StrategyPolicy, + linear_strategy_policy, ) # Import scenario submodules directly and register them as virtual subpackages @@ -73,6 +74,7 @@ "ScenarioResult", "StrategyGraph", "StrategyPolicy", + "linear_strategy_policy", "airt", "benchmark", "garak", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 8fb32f303..8aec1f77d 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -17,7 +17,7 @@ SCENARIO_TECHNIQUES, register_scenario_techniques, ) -from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy +from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy, linear_strategy_policy __all__ = [ "AtomicAttack", @@ -39,6 +39,7 @@ "ScorerOverridePolicy", "StrategyGraph", "StrategyPolicy", + "linear_strategy_policy", "register_scenario_techniques", "get_default_scorer_target", "get_default_adversarial_target", diff --git a/pyrit/scenario/core/strategy_graph.py b/pyrit/scenario/core/strategy_graph.py index ac446186d..f426f041a 100644 --- a/pyrit/scenario/core/strategy_graph.py +++ b/pyrit/scenario/core/strategy_graph.py @@ -9,15 +9,17 @@ returns a ``(next_state, ScenarioStepResult | None)`` pair; the graph yields the result and advances. Terminal states stop the loop. -This module is part of the scenario core refactor scaffold (Phase 0). The -skeleton supports straight-line execution; richer policy composition (graph -validation, cycle detection, parallel branches) lands as Phase 3 needs it. +This module also exposes ``linear_strategy_policy``, a convenience builder +that produces a trivial "run steps 0..N-1 in order" policy. Phase 5 will use +it to silently upgrade scenarios that still declare their steps as a flat +list (via the legacy ``_get_atomic_attacks_async`` override) without forcing +those scenarios to author a custom policy. """ from __future__ import annotations import logging -from collections.abc import AsyncIterator, Awaitable, Callable, Hashable, Mapping +from collections.abc import AsyncIterator, Awaitable, Callable, Hashable, Mapping, Sequence from dataclasses import dataclass, field from types import MappingProxyType from typing import TYPE_CHECKING, Generic, TypeVar @@ -239,8 +241,60 @@ async def event_loop_async(self) -> AsyncIterator[ScenarioStepResult]: self._current_state = next_state +def linear_strategy_policy( + steps: Sequence[ScenarioStep], +) -> StrategyPolicy[ScenarioStep, int]: + """ + Build a trivial linear-traversal policy over an ordered list of steps. + + State ``i`` binds ``steps[i]`` as the graph's current step, awaits its + ``process_async``, and transitions to state ``i + 1``. State ``len(steps)`` + is the sole terminal state. + + Args: + steps (Sequence[ScenarioStep]): Steps to execute in order. Must be + non-empty. + + Returns: + StrategyPolicy[ScenarioStep, int]: A policy that walks the steps + sequentially. Pass it to ``StrategyGraph(policy=...)`` to get the + runnable graph. + + Raises: + ValueError: If ``steps`` is empty. + """ + if not steps: + raise ValueError("linear_strategy_policy requires at least one step.") + + terminal = len(steps) + actions: dict[int, PolicyAction[ScenarioStep, int]] = {} + + for index, step in enumerate(steps): + + async def _action( + graph: StrategyGraph[ScenarioStep, int], + _step: ScenarioStep = step, + _next: int = index + 1, + ) -> tuple[int, ScenarioStepResult | None]: + graph.bind_current_step(step=_step) + try: + result = await _step.process_async() + finally: + graph.bind_current_step(step=None) + return _next, result + + actions[index] = _action + + return StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({terminal}), + ) + + __all__ = [ "StrategyGraph", "StrategyPolicy", "PolicyAction", + "linear_strategy_policy", ] diff --git a/tests/unit/scenario/test_linear_strategy_policy.py b/tests/unit/scenario/test_linear_strategy_policy.py new file mode 100644 index 000000000..4f9a631c4 --- /dev/null +++ b/tests/unit/scenario/test_linear_strategy_policy.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the ``linear_strategy_policy`` convenience builder. + +These tests double as the silent-upgrade contract for Phase 5: the legacy +flat ``_get_atomic_attacks_async`` -> ``list[AtomicAttack]`` contract becomes +a ``StrategyGraph`` driven by ``linear_strategy_policy``. If anything here +breaks, the silent-upgrade path is broken. +""" + +from typing import Any + +import pytest + +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.strategy_graph import StrategyGraph, linear_strategy_policy + + +class _RecordingStep(ScenarioStep): + """Minimal ``ScenarioStep`` that records each invocation.""" + + def __init__(self, *, name: str, outcome: str = "done"): + self._name = name + self._outcome = outcome + self.call_count = 0 + + @property + def name(self) -> str: # type: ignore[override] + return self._name + + @property + def outputs(self) -> list[str]: # type: ignore[override] + return [self._outcome] + + async def process_async(self) -> ScenarioStepResult: + self.call_count += 1 + return ScenarioStepResult( + outcome=self._outcome, + metadata={"name": self._name, "call": self.call_count}, + ) + + +def test_empty_steps_raises(): + with pytest.raises(ValueError, match="at least one step"): + linear_strategy_policy([]) + + +def test_single_step_policy_has_one_terminal_state(): + step = _RecordingStep(name="solo") + policy = linear_strategy_policy([step]) + + assert policy.initial_state == 0 + assert policy.terminal_states == frozenset({1}) + assert policy.is_terminal(state=1) + assert not policy.is_terminal(state=0) + + +async def test_linear_policy_executes_steps_in_order(): + steps = [_RecordingStep(name=f"step_{i}") for i in range(3)] + graph: StrategyGraph[ScenarioStep, int] = StrategyGraph(policy=linear_strategy_policy(steps)) + + results = [r async for r in graph.event_loop_async()] + + assert len(results) == 3 + assert [r.metadata["name"] for r in results] == ["step_0", "step_1", "step_2"] + assert [s.call_count for s in steps] == [1, 1, 1] + assert graph.current_state == 3 + assert graph.is_terminal + + +async def test_linear_policy_binds_current_step_during_execution(): + """The action must bind the step before invoking it and clear after.""" + binding_history: list[Any] = [] + + class _BindingObserverStep(_RecordingStep): + def __init__(self, *, name: str, graph_ref: list[Any]): + super().__init__(name=name) + self._graph_ref = graph_ref + + async def process_async(self) -> ScenarioStepResult: + graph = self._graph_ref[0] + binding_history.append(graph.current_step) + return await super().process_async() + + graph_ref: list[Any] = [None] + steps = [_BindingObserverStep(name="a", graph_ref=graph_ref), _BindingObserverStep(name="b", graph_ref=graph_ref)] + graph: StrategyGraph[ScenarioStep, int] = StrategyGraph(policy=linear_strategy_policy(steps)) + graph_ref[0] = graph + + _ = [r async for r in graph.event_loop_async()] + + # During each step's process_async, current_step should be that step. + assert binding_history == steps + # After traversal completes, current_step is cleared. + assert graph.current_step is None + + +async def test_linear_policy_clears_current_step_on_failure(): + """If process_async raises, the bound step is still cleared via finally.""" + + class _FailingStep(_RecordingStep): + async def process_async(self) -> ScenarioStepResult: + raise RuntimeError("boom") + + steps: list[ScenarioStep] = [_FailingStep(name="fails")] + graph: StrategyGraph[ScenarioStep, int] = StrategyGraph(policy=linear_strategy_policy(steps)) + + with pytest.raises(RuntimeError, match="boom"): + _ = [r async for r in graph.event_loop_async()] + + assert graph.current_step is None + # Graph stayed at initial state — retry can re-enter. + assert graph.current_state == 0 + + +async def test_linear_policy_each_action_runs_its_own_step(): + """Late binding bug guard: every action must run its own indexed step.""" + steps = [_RecordingStep(name=f"step_{i}", outcome=f"out_{i}") for i in range(4)] + graph: StrategyGraph[ScenarioStep, int] = StrategyGraph(policy=linear_strategy_policy(steps)) + + results = [r async for r in graph.event_loop_async()] + assert [r.outcome for r in results] == ["out_0", "out_1", "out_2", "out_3"] diff --git a/tests/unit/scenario/test_strategy_graph_branching.py b/tests/unit/scenario/test_strategy_graph_branching.py new file mode 100644 index 000000000..2d98cad29 --- /dev/null +++ b/tests/unit/scenario/test_strategy_graph_branching.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Branching ``StrategyGraph`` integration test. + +The rest of Phase 0/3's coverage exercises the policy machinery on toy +two-step flat graphs. This module forces the API through a non-trivial +branching scenario *before* Phase 5 commits ``Scenario.run_async`` to it. + +Scenario: ``BroadSweepThenDeepDive`` + OPENING_PHASE + runs a broad-sweep step that emits one of: + ``"safe"`` — every category passed, transition to COMPLETE + ``"violation"`` — at least one category leaked, transition to + ESCALATION_PHASE + ESCALATION_PHASE + runs a deep-dive step that always emits ``"done"`` and transitions + to COMPLETE. + +This is intentionally not a real scenario (lives in tests/), but it has the +exact shape the first real graph-based scenarios will need: outcome-driven +branching with at least one self-loop-free path that skips a step. +""" + +from enum import Enum + +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy + + +class _Phase(str, Enum): + OPENING = "opening" + ESCALATION = "escalation" + COMPLETE = "complete" + + +class _SweepStep(ScenarioStep): + """Broad-sweep step. Emits ``"safe"`` or ``"violation"`` per fixture state.""" + + def __init__(self, *, sweep_outcome: str): + self._sweep_outcome = sweep_outcome + self.call_count = 0 + + @property + def name(self) -> str: # type: ignore[override] + return "broad_sweep" + + @property + def outputs(self) -> list[str]: # type: ignore[override] + return ["safe", "violation"] + + async def process_async(self) -> ScenarioStepResult: + self.call_count += 1 + return ScenarioStepResult(outcome=self._sweep_outcome) + + +class _EscalationStep(ScenarioStep): + """Deep-dive step. Only runs if the sweep emitted a violation.""" + + def __init__(self) -> None: + self.call_count = 0 + + @property + def name(self) -> str: # type: ignore[override] + return "deep_dive" + + @property + def outputs(self) -> list[str]: # type: ignore[override] + return ["done"] + + async def process_async(self) -> ScenarioStepResult: + self.call_count += 1 + return ScenarioStepResult(outcome="done", metadata={"escalated": True}) + + +def _build_graph(*, sweep_outcome: str) -> tuple[StrategyGraph, _SweepStep, _EscalationStep]: + sweep = _SweepStep(sweep_outcome=sweep_outcome) + escalation = _EscalationStep() + + async def opening_action(graph): + graph.bind_current_step(step=sweep) + try: + result = await sweep.process_async() + finally: + graph.bind_current_step(step=None) + next_state = _Phase.COMPLETE if result.outcome == "safe" else _Phase.ESCALATION + return next_state, result + + async def escalation_action(graph): + graph.bind_current_step(step=escalation) + try: + result = await escalation.process_async() + finally: + graph.bind_current_step(step=None) + return _Phase.COMPLETE, result + + policy: StrategyPolicy[ScenarioStep, _Phase] = StrategyPolicy( + actions={_Phase.OPENING: opening_action, _Phase.ESCALATION: escalation_action}, + initial_state=_Phase.OPENING, + terminal_states=frozenset({_Phase.COMPLETE}), + ) + return StrategyGraph(policy=policy), sweep, escalation + + +async def test_safe_sweep_short_circuits_to_complete(): + graph, sweep, escalation = _build_graph(sweep_outcome="safe") + + results = [r async for r in graph.event_loop_async()] + outcomes = [r.outcome for r in results] + + assert outcomes == ["safe"] + assert sweep.call_count == 1 + assert escalation.call_count == 0 + assert graph.current_state == _Phase.COMPLETE + assert graph.is_terminal + + +async def test_violation_sweep_triggers_escalation(): + graph, sweep, escalation = _build_graph(sweep_outcome="violation") + + results = [r async for r in graph.event_loop_async()] + outcomes = [r.outcome for r in results] + + assert outcomes == ["violation", "done"] + assert sweep.call_count == 1 + assert escalation.call_count == 1 + assert graph.current_state == _Phase.COMPLETE + assert graph.is_terminal + # The escalation step's metadata survived the round trip. + assert results[1].metadata["escalated"] is True + + +async def test_history_records_both_branches(): + graph, _, _ = _build_graph(sweep_outcome="violation") + _ = [r async for r in graph.event_loop_async()] + + states_before = [state for state, _ in graph.history] + assert states_before == [_Phase.OPENING, _Phase.ESCALATION] + + +async def test_reset_replays_branching_graph(): + """A graph that branched once can be reset and re-run on the other branch.""" + graph, sweep, escalation = _build_graph(sweep_outcome="violation") + _ = [r async for r in graph.event_loop_async()] + assert graph.current_state == _Phase.COMPLETE + + graph.reset() + assert graph.current_state == _Phase.OPENING + assert graph.history == [] + + # Re-run produces the same branch (action closures captured the same state). + results = [r async for r in graph.event_loop_async()] + assert [r.outcome for r in results] == ["violation", "done"] + assert sweep.call_count == 2 + assert escalation.call_count == 2 From b372d08878aaac55feb4f397db7d948abe8c051b Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 12:52:42 -0700 Subject: [PATCH 04/42] feat(scenario): add step_identifier persistence layer (Phase 4) Lands the additive `step_identifier` column on `AttackResultEntry` so `AttackResult` rows produced through the new `StrategyGraph` orchestrator carry the composite `ScenarioStep` identity built by `pyrit.identifiers.step_identifier.build_step_identifier` (introduced in Phase 0). Old rows stay null - no backfill, no destructive migration. Per the Phase 4 plan, `atomic_attack_identifier` is NOT renamed and NOT removed. `step_identifier` is purely additive metadata that records *which step inside which scenario* produced the attack result. Direct attack invocations continue to set only `atomic_attack_identifier` and write `step_identifier = null`. Changes: * pyrit/identifiers/evaluation_identifier.py - new `StepEvaluationIdentifier` mirroring `AtomicAttackEvaluationIdentifier.CHILD_EVAL_RULES` so nested attack-execution children get filtered identically inside step-level eval grouping. The step's own params (`step_name`, `outcome`, `eval_version`) are fully included - a `STEP_EVAL_VERSION` bump splits two semantically-equivalent step runs. * pyrit/identifiers/identifier_filters.py - `IdentifierType.STEP`. * pyrit/identifiers/__init__.py - exports `StepEvaluationIdentifier`. * pyrit/memory/alembic/versions/a1c2e4f80b3d_add_step_identifier.py - new migration chaining off `7a1b2c3d4e5f` adding a nullable JSON column. * pyrit/memory/memory_models.py - `AttackResultEntry.step_identifier` JSON column; `__init__` populates `eval_hash` via `StepEvaluationIdentifier` BEFORE the `to_dict` truncation pass so the hash survives DB storage, mirroring the atomic_attack_identifier precedent; `get_attack_result` reconstructs via `ComponentIdentifier.from_dict`. * pyrit/memory/memory_interface.py - `identifier_column_map` extended so `IdentifierType.STEP` filters route to the new column. * pyrit/models/attack_result.py - `step_identifier: Optional` field added to the dataclass + `to_dict` / `from_dict`. Old payloads without the key still hydrate cleanly. Tests (+18 new, all passing; full unit suite 7947 passed, 118 skipped, 1 pre-existing ODBC env failure): * test_step_evaluation_identifier.py - eval-hash stability, outcome / nested-target / eval_version sensitivity, scorer / operational-param exclusions, rule parity with AtomicAttackEvaluationIdentifier. * test_memory_models.py - AttackResultEntry round-trip with and without step_identifier, eval_hash preservation through the column. * test_attack_result.py - to_dict / from_dict round-trip; null behavior. * test_interface_attack_results.py - SQLite filter by `IdentifierType.STEP` matches step_name and skips legacy rows. * test_identifier_filters.py - guard test count + value assertion. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/identifiers/__init__.py | 2 + pyrit/identifiers/evaluation_identifier.py | 33 ++++ pyrit/identifiers/identifier_filters.py | 1 + .../a1c2e4f80b3d_add_step_identifier.py | 37 ++++ pyrit/memory/memory_interface.py | 5 +- pyrit/memory/memory_models.py | 18 ++ pyrit/models/attack_result.py | 12 ++ .../identifiers/test_identifier_filters.py | 3 +- .../test_step_evaluation_identifier.py | 177 ++++++++++++++++++ .../test_interface_attack_results.py | 60 ++++++ tests/unit/memory/test_memory_models.py | 58 ++++++ tests/unit/models/test_attack_result.py | 45 +++++ 12 files changed, 449 insertions(+), 2 deletions(-) create mode 100644 pyrit/memory/alembic/versions/a1c2e4f80b3d_add_step_identifier.py create mode 100644 tests/unit/identifiers/test_step_evaluation_identifier.py diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index c8e46b5ae..35df8b578 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -19,6 +19,7 @@ ChildEvalRule, EvaluationIdentifier, ScorerEvaluationIdentifier, + StepEvaluationIdentifier, compute_eval_hash, ) from pyrit.identifiers.identifier_filters import IdentifierFilter, IdentifierType @@ -39,6 +40,7 @@ "ScorerEvaluationIdentifier", "snake_case_to_class_name", "STEP_EVAL_VERSION", + "StepEvaluationIdentifier", "validate_registry_name", "config_hash", "IdentifierFilter", diff --git a/pyrit/identifiers/evaluation_identifier.py b/pyrit/identifiers/evaluation_identifier.py index 88a73d469..4cc9920d2 100644 --- a/pyrit/identifiers/evaluation_identifier.py +++ b/pyrit/identifiers/evaluation_identifier.py @@ -276,3 +276,36 @@ class AtomicAttackEvaluationIdentifier(EvaluationIdentifier): # attack_technique: not listed in rules — fully included in eval hash. # technique_seeds (nested inside attack_technique): also not listed — fully included. } + + +class StepEvaluationIdentifier(EvaluationIdentifier): + """ + Evaluation identity for ``ScenarioStep`` executions. + + A step identifier wraps one or more ``atomic_attack_identifier`` children + under ``attack_executions``; this class reuses + ``AtomicAttackEvaluationIdentifier.CHILD_EVAL_RULES`` for those nested + attack-execution children so per-attack eval semantics are preserved + inside step-level eval grouping. + + The step's own ``params`` (``step_name``, ``outcome``, ``eval_version``) + are fully included so two semantically-equivalent step runs with the same + name and outcome land in the same eval group, but a schema bump + (``eval_version``) splits them — matching the additive contract spelled + out at the top of ``pyrit.identifiers.step_identifier``. + """ + + CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = { + # Mirror the per-attack rules so each nested atomic_attack_identifier + # is filtered exactly the way AtomicAttackEvaluationIdentifier would + # filter it on its own. + "objective_target": ChildEvalRule( + included_params=frozenset({"temperature"}), + ), + "adversarial_chat": ChildEvalRule( + included_params=frozenset({"underlying_model_name", "temperature", "top_p"}), + param_fallbacks={"underlying_model_name": "model_name"}, + ), + "objective_scorer": ChildEvalRule(exclude=True), + "seed_identifiers": ChildEvalRule(exclude=True), + } diff --git a/pyrit/identifiers/identifier_filters.py b/pyrit/identifiers/identifier_filters.py index bd217e4a0..7c0c8d071 100644 --- a/pyrit/identifiers/identifier_filters.py +++ b/pyrit/identifiers/identifier_filters.py @@ -9,6 +9,7 @@ class IdentifierType(Enum): """Enumeration of supported identifier types for filtering.""" ATTACK = "attack" + STEP = "step" TARGET = "target" SCORER = "scorer" CONVERTER = "converter" diff --git a/pyrit/memory/alembic/versions/a1c2e4f80b3d_add_step_identifier.py b/pyrit/memory/alembic/versions/a1c2e4f80b3d_add_step_identifier.py new file mode 100644 index 000000000..97276f536 --- /dev/null +++ b/pyrit/memory/alembic/versions/a1c2e4f80b3d_add_step_identifier.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +add step_identifier to AttackResultEntries. + +Revision ID: a1c2e4f80b3d +Revises: 7a1b2c3d4e5f +Create Date: 2026-05-20 12:00:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1c2e4f80b3d" +down_revision: str | None = "7a1b2c3d4e5f" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # Additive nullable column: scenarios that opt into StrategyGraph populate + # this with the composite ScenarioStep identifier; legacy and direct-attack + # rows leave it null. No backfill needed. + op.add_column( + "AttackResultEntries", + sa.Column("step_identifier", sa.JSON(), nullable=True), + ) + + +def downgrade() -> None: + """Revert this schema upgrade.""" + op.drop_column("AttackResultEntries", "step_identifier") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 65d4480aa..552544c70 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1863,7 +1863,10 @@ def get_attack_results( conditions.extend( self._build_identifier_filter_conditions( identifier_filters=identifier_filters, - identifier_column_map={IdentifierType.ATTACK: AttackResultEntry.atomic_attack_identifier}, + identifier_column_map={ + IdentifierType.ATTACK: AttackResultEntry.atomic_attack_identifier, + IdentifierType.STEP: AttackResultEntry.step_identifier, + }, caller="get_attack_results", ) ) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index efb9ec72a..33d475015 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -35,6 +35,7 @@ from pyrit.identifiers.evaluation_identifier import ( AtomicAttackEvaluationIdentifier, ScorerEvaluationIdentifier, + StepEvaluationIdentifier, ) from pyrit.models import ( AttackOutcome, @@ -710,6 +711,7 @@ class AttackResultEntry(Base): objective = mapped_column(Unicode, nullable=False) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) atomic_attack_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + step_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) objective_sha256 = mapped_column(String, nullable=True) last_response_id: Mapped[uuid.UUID | None] = mapped_column( CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), nullable=True @@ -778,6 +780,19 @@ def __init__(self, *, entry: AttackResult) -> None: if entry.atomic_attack_identifier else None ) + # Ensure eval_hash is set on the step identifier so it survives the DB + # round-trip the same way atomic_attack_identifier does above. + if entry.step_identifier and entry.step_identifier.eval_hash is None: + entry.step_identifier = entry.step_identifier.with_eval_hash( + StepEvaluationIdentifier(entry.step_identifier).eval_hash + ) + self.step_identifier = ( + entry.step_identifier.to_dict( + max_value_length=MAX_IDENTIFIER_VALUE_LENGTH, + ) + if entry.step_identifier + else None + ) self.objective_sha256 = to_sha256(entry.objective) # Use helper method for UUID conversions @@ -900,6 +915,8 @@ def get_attack_result(self) -> AttackResult: attack_identifier=ComponentIdentifier.from_dict(self.attack_identifier), ) + step_id = ComponentIdentifier.from_dict(self.step_identifier) if self.step_identifier else None + # Deserialize retry events from JSON retry_events = [] if self.retry_events_json: @@ -912,6 +929,7 @@ def get_attack_result(self) -> AttackResult: attack_result_id=str(self.id), objective=self.objective, atomic_attack_identifier=atomic_id, + step_identifier=step_id, last_response=self.last_response.get_message_piece() if self.last_response else None, last_score=self.last_score.get_score() if self.last_score else None, executed_turns=self.executed_turns, diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 703a2b90a..ea388e8de 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -63,6 +63,12 @@ class AttackResult(StrategyResult): # Contains the attack strategy as children["attack"] plus optional seeds. atomic_attack_identifier: Optional[ComponentIdentifier] = None + # Composite identifier for the ScenarioStep execution that produced this + # result. Populated only for scenarios that run through StrategyGraph; + # ``None`` for legacy flat scenarios and for any direct attack invocation + # outside the scenario layer. Built via ``build_step_identifier``. + step_identifier: Optional[ComponentIdentifier] = None + # Evidence # Model response generated in the final turn of the attack last_response: Optional[MessagePiece] = None @@ -234,6 +240,7 @@ def to_dict(self) -> dict[str, Any]: "atomic_attack_identifier": ( self.atomic_attack_identifier.to_dict() if self.atomic_attack_identifier else None ), + "step_identifier": (self.step_identifier.to_dict() if self.step_identifier else None), "last_response": self.last_response.to_dict() if self.last_response else None, "last_score": self.last_score.to_dict() if self.last_score else None, "executed_turns": self.executed_turns, @@ -274,6 +281,11 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: if data.get("atomic_attack_identifier") else None ), + step_identifier=( + ComponentIdentifier.from_dict(data["step_identifier"]) + if data.get("step_identifier") + else None + ), last_response=(MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None), last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, executed_turns=data.get("executed_turns", 0), diff --git a/tests/unit/identifiers/test_identifier_filters.py b/tests/unit/identifiers/test_identifier_filters.py index 7e66ba95d..1aaba4804 100644 --- a/tests/unit/identifiers/test_identifier_filters.py +++ b/tests/unit/identifiers/test_identifier_filters.py @@ -13,10 +13,11 @@ def test_identifier_type_values(): assert IdentifierType.TARGET.value == "target" assert IdentifierType.SCORER.value == "scorer" assert IdentifierType.CONVERTER.value == "converter" + assert IdentifierType.STEP.value == "step" def test_identifier_type_member_count(): - assert len(IdentifierType) == 4 + assert len(IdentifierType) == 5 # --- IdentifierFilter creation --- diff --git a/tests/unit/identifiers/test_step_evaluation_identifier.py b/tests/unit/identifiers/test_step_evaluation_identifier.py new file mode 100644 index 000000000..082adbdc1 --- /dev/null +++ b/tests/unit/identifiers/test_step_evaluation_identifier.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for ``StepEvaluationIdentifier``. + +Mirrors ``test_atomic_attack_identifier.py::TestAtomicAttackEvaluationIdentifier`` +for the step layer that wraps one or more atomic attack identifiers under a +single ``attack_executions`` children entry. +""" + +from pyrit.identifiers import ( + AtomicAttackEvaluationIdentifier, + ComponentIdentifier, + StepEvaluationIdentifier, + build_atomic_attack_identifier, +) +from pyrit.identifiers.step_identifier import ( + STEP_EVAL_VERSION, + build_step_identifier, +) + +_ATTACK_MODULE = "pyrit.executor.attack.single_turn.prompt_sending" +_TARGET_MODULE = "pyrit.prompt_target.openai.openai_chat_target" + + +def _make_target(*, params: dict | None = None) -> ComponentIdentifier: + return ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module=_TARGET_MODULE, + params=params or {}, + ) + + +def _build_step( + *, + outcome: str = "done", + target_temp: float = 0.7, + objective_scorer: ComponentIdentifier | None = None, +) -> ComponentIdentifier: + attack_children: dict = {"objective_target": [_make_target(params={"temperature": target_temp})]} + if objective_scorer is not None: + attack_children["objective_scorer"] = [objective_scorer] + + attack = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module=_ATTACK_MODULE, + children=attack_children, + ) + atomic = build_atomic_attack_identifier(attack_identifier=attack) + return build_step_identifier( + step_name="opening_phase", + outcome=outcome, + attack_execution_identifiers=[atomic], + ) + + +class TestStepEvaluationIdentifier: + """Behavior of the eval-hash wrapper for step identifiers.""" + + def test_eval_hash_is_64_char_hex(self): + ident = _build_step() + eval_hash = StepEvaluationIdentifier(ident).eval_hash + assert len(eval_hash) == 64 + int(eval_hash, 16) + + def test_identifier_property_returns_original(self): + ident = _build_step() + wrapper = StepEvaluationIdentifier(ident) + assert wrapper.identifier is ident + + def test_preserved_eval_hash_from_round_trip(self): + # Once an eval_hash is stamped on the identifier (DB round-trip), + # the wrapper trusts it rather than recomputing. + ident = _build_step() + computed = StepEvaluationIdentifier(ident).eval_hash + stamped = ComponentIdentifier( + class_name=ident.class_name, + class_module=ident.class_module, + params=dict(ident.params), + children=dict(ident.children), + eval_hash=computed, + ) + assert StepEvaluationIdentifier(stamped).eval_hash == computed + + def test_same_outcome_same_eval_hash(self): + a = StepEvaluationIdentifier(_build_step(outcome="done")).eval_hash + b = StepEvaluationIdentifier(_build_step(outcome="done")).eval_hash + assert a == b + + def test_different_outcome_different_eval_hash(self): + done = StepEvaluationIdentifier(_build_step(outcome="done")).eval_hash + violation = StepEvaluationIdentifier(_build_step(outcome="safety_violation")).eval_hash + assert done != violation + + def test_nested_objective_target_operational_params_ignored(self): + # Mirror AtomicAttackEvaluationIdentifier: only temperature matters + # on objective_target; endpoint / max_requests_per_minute should not + # contribute to the step eval hash either. + baseline = _build_step() + + noisy_target = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module=_TARGET_MODULE, + params={ + "temperature": 0.7, + "endpoint": "https://noisy.example.com", + "max_requests_per_minute": 999, + }, + ) + noisy_attack = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module=_ATTACK_MODULE, + children={"objective_target": [noisy_target]}, + ) + noisy_atomic = build_atomic_attack_identifier(attack_identifier=noisy_attack) + noisy = build_step_identifier( + step_name="opening_phase", + outcome="done", + attack_execution_identifiers=[noisy_atomic], + ) + assert ( + StepEvaluationIdentifier(noisy).eval_hash + == StepEvaluationIdentifier(baseline).eval_hash + ) + + def test_nested_objective_target_temperature_change_changes_hash(self): + hot = StepEvaluationIdentifier(_build_step(target_temp=0.7)).eval_hash + cool = StepEvaluationIdentifier(_build_step(target_temp=0.2)).eval_hash + assert hot != cool + + def test_nested_objective_scorer_excluded(self): + # objective_scorer should be excluded from the step eval hash, exactly + # as it is from the atomic eval hash. + without_scorer = StepEvaluationIdentifier(_build_step()).eval_hash + scorer = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score.true_false.self_ask_true_false_scorer", + params={"true_false_question": "Is the response harmful?"}, + ) + with_scorer = StepEvaluationIdentifier(_build_step(objective_scorer=scorer)).eval_hash + assert without_scorer == with_scorer + + def test_eval_version_bump_changes_hash(self, monkeypatch): + original = StepEvaluationIdentifier(_build_step()).eval_hash + monkeypatch.setattr("pyrit.identifiers.step_identifier.STEP_EVAL_VERSION", STEP_EVAL_VERSION + 1) + bumped = StepEvaluationIdentifier(_build_step()).eval_hash + assert original != bumped + + def test_step_name_change_changes_hash(self): + opening = StepEvaluationIdentifier(_build_step(outcome="done")).eval_hash + escalation = StepEvaluationIdentifier( + build_step_identifier( + step_name="escalation_phase", + outcome="done", + attack_execution_identifiers=[ + build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier( + class_name="PromptSendingAttack", + class_module=_ATTACK_MODULE, + children={"objective_target": [_make_target(params={"temperature": 0.7})]}, + ) + ) + ], + ) + ).eval_hash + assert opening != escalation + + def test_mirrors_atomic_rules_at_step_level(self): + # StepEvaluationIdentifier reuses the same child-name rules as + # AtomicAttackEvaluationIdentifier so nested attack children get + # filtered identically. + atomic_rules = AtomicAttackEvaluationIdentifier.CHILD_EVAL_RULES + step_rules = StepEvaluationIdentifier.CHILD_EVAL_RULES + for name, rule in atomic_rules.items(): + assert name in step_rules + assert step_rules[name] == rule diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index d5729cf11..6cc6ddf97 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -1696,6 +1696,66 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance assert len(results) == 0 +def test_get_attack_results_by_step_identifier_filter_step_name(sqlite_instance: MemoryInterface): + """Filter attack results by step_identifier step_name (Phase 4).""" + from pyrit.identifiers.step_identifier import build_step_identifier + + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar1.step_identifier = build_step_identifier( + step_name="opening_phase", + outcome="done", + attack_execution_identifiers=[ar1.atomic_attack_identifier], + ) + ar2.step_identifier = build_step_identifier( + step_name="escalation_phase", + outcome="done", + attack_execution_identifiers=[ar2.atomic_attack_identifier], + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + results = sqlite_instance.get_attack_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.STEP, + property_path="$.step_name", + value="opening_phase", + partial_match=False, + ) + ], + ) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + assert results[0].step_identifier is not None + assert results[0].step_identifier.params["step_name"] == "opening_phase" + + +def test_get_attack_results_by_step_identifier_filter_skips_legacy_rows(sqlite_instance: MemoryInterface): + """Attack results without a step_identifier never match a STEP filter.""" + from pyrit.identifiers.step_identifier import build_step_identifier + + legacy_ar = _make_attack_result_with_identifier("conv_legacy", "CrescendoAttack") + new_ar = _make_attack_result_with_identifier("conv_new", "CrescendoAttack") + new_ar.step_identifier = build_step_identifier( + step_name="opening_phase", + outcome="done", + attack_execution_identifiers=[new_ar.atomic_attack_identifier], + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[legacy_ar, new_ar]) + + results = sqlite_instance.get_attack_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.STEP, + property_path="$.step_name", + value="opening_phase", + partial_match=False, + ) + ], + ) + assert [r.conversation_id for r in results] == ["conv_new"] + + def test_get_attack_results_targeted_harm_categories_emits_deprecation_warning(sqlite_instance: MemoryInterface): """Test that passing targeted_harm_categories emits a DeprecationWarning.""" import warnings diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index b382bf864..2c1a0a29f 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -406,6 +406,64 @@ def test_get_attack_result_prefers_atomic_over_stale_attack_identifier(self): assert strategy is not None assert strategy.class_name == "CorrectAttack" + # --- step_identifier (Phase 4) --- + + def test_step_identifier_round_trip(self): + """An AttackResult carrying a step_identifier survives DB persistence.""" + from pyrit.identifiers.step_identifier import build_step_identifier + + attack_id = ComponentIdentifier(class_name="PromptSending", class_module="pyrit.executor.attack") + atomic_id = build_atomic_attack_identifier(attack_identifier=attack_id) + step_id = build_step_identifier( + step_name="opening_phase", + outcome="safety_violation", + attack_execution_identifiers=[atomic_id], + ) + ar = _make_attack_result(atomic_attack_identifier=atomic_id, step_identifier=step_id) + entry = AttackResultEntry(entry=ar) + + # Column is populated as a serialized dict (flat params, per ComponentIdentifier.to_dict). + assert entry.step_identifier is not None + assert entry.step_identifier["class_name"] == "ScenarioStep" + assert entry.step_identifier["step_name"] == "opening_phase" + assert entry.step_identifier["outcome"] == "safety_violation" + + round_tripped = entry.get_attack_result() + assert round_tripped.step_identifier is not None + assert round_tripped.step_identifier.class_name == "ScenarioStep" + assert round_tripped.step_identifier.params["step_name"] == "opening_phase" + assert round_tripped.step_identifier.params["outcome"] == "safety_violation" + + def test_no_step_identifier_stays_none(self): + """Legacy results without a step_identifier remain None after round-trip.""" + ar = _make_attack_result() + assert ar.step_identifier is None + entry = AttackResultEntry(entry=ar) + assert entry.step_identifier is None + assert entry.get_attack_result().step_identifier is None + + def test_step_identifier_eval_hash_preserved(self): + """The step identifier's eval_hash is stamped on the column dict and survives a round-trip.""" + from pyrit.identifiers.step_identifier import build_step_identifier + + attack_id = ComponentIdentifier(class_name="PromptSending", class_module="pyrit.executor.attack") + atomic_id = build_atomic_attack_identifier(attack_identifier=attack_id) + step_id = build_step_identifier( + step_name="opening_phase", + outcome="done", + attack_execution_identifiers=[atomic_id], + ) + ar = _make_attack_result(atomic_attack_identifier=atomic_id, step_identifier=step_id) + entry = AttackResultEntry(entry=ar) + + assert entry.step_identifier is not None + assert entry.step_identifier.get("eval_hash") is not None + assert len(entry.step_identifier["eval_hash"]) == 64 + + round_tripped = entry.get_attack_result() + assert round_tripped.step_identifier is not None + assert round_tripped.step_identifier.eval_hash == entry.step_identifier["eval_hash"] + # --------------------------------------------------------------------------- # ScenarioResultEntry diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index a2db52f53..cfeb35d80 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -435,3 +435,48 @@ def test_to_dict_from_dict_roundtrip(): ) roundtripped = AttackResult.from_dict(original.to_dict()) assert original.to_dict() == roundtripped.to_dict() + + +def test_to_dict_from_dict_roundtrip_with_step_identifier(): + """AttackResult round-trips its step_identifier (Phase 4).""" + from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier + from pyrit.identifiers.step_identifier import build_step_identifier + + attack_id = ComponentIdentifier(class_name="PromptSendingAttack", class_module="pyrit.executor.attack") + atomic_id = build_atomic_attack_identifier(attack_identifier=attack_id) + step_id = build_step_identifier( + step_name="opening_phase", + outcome="safety_violation", + attack_execution_identifiers=[atomic_id], + ) + original = AttackResult( + conversation_id="conv-1", + objective="Generate harmful content", + atomic_attack_identifier=atomic_id, + step_identifier=step_id, + executed_turns=1, + execution_time_ms=100, + outcome=AttackOutcome.SUCCESS, + outcome_reason="achieved", + ) + roundtripped = AttackResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() + assert roundtripped.step_identifier is not None + assert roundtripped.step_identifier.class_name == "ScenarioStep" + assert roundtripped.step_identifier.params["step_name"] == "opening_phase" + assert roundtripped.step_identifier.params["outcome"] == "safety_violation" + + +def test_to_dict_omits_step_identifier_when_none(): + """When step_identifier is not set, to_dict carries it as None (additive, never absent).""" + original = AttackResult( + conversation_id="conv-1", + objective="test", + executed_turns=1, + execution_time_ms=10, + outcome=AttackOutcome.SUCCESS, + ) + assert original.step_identifier is None + serialized = original.to_dict() + assert "step_identifier" in serialized + assert serialized["step_identifier"] is None From 952311d9cffdcae8dc65bad917217bc1262ae23d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:12:32 -0700 Subject: [PATCH 05/42] MAINT: rewire Scenario.run_async through StrategyGraph (Phase 5) Phase 5 of the scenario-core refactor moves Scenario.run_async from a flat for-loop over AtomicAttacks to a StrategyGraph event loop, without changing observable behavior for any existing scenario. Key changes in pyrit/scenario/core/scenario.py: * New `_build_execution_graph(*, steps=None)` factory returns the StrategyGraph that drives the execution attempt. Default implementation wraps the supplied steps (or self._atomic_attacks) via `_build_default_linear_policy`, which preserves AtomicAttack-level concurrency semantics (max_concurrency, return_partial_on_failure) and stamps each step's name into ScenarioStepResult.metadata['step_name'] so the orchestrator can identify yields without depending on graph.current_step. * `_execute_scenario_async` now iterates `self._execution_graph.event_loop_async()` instead of the flat remaining_attacks list. Resume-by-name semantics are preserved: `_get_remaining_atomic_attacks_async` runs first, the graph is built from its output, and already-completed steps simply aren't in the policy. Partial-failure handling, retry, scenario_run_state transitions, error_attack_result_ids persistence, and progress-bar continuity all behave identically. * Each AttackResult flowing out of the graph is stamped with a step_identifier (the Phase 4 column) and that identifier is pushed to the existing AttackResultEntry row via update_attack_result_by_id, mirroring AtomicAttack._enrich_atomic_attack_identifiers. Steps that pre-stamp their own step_identifier (e.g., future adaptive steps) are not overwritten. * New public properties `execution_graph` and `execution_history` expose the active attempt's state machine for inspection and downstream tooling. Tests: * New tests/unit/scenario/test_scenario_graph_execution.py (11 tests) pins the new public surface: graph factory contract, execution_graph/execution_history properties, step_identifier stamping (default and pre-stamped), max_concurrency propagation, partial-failure surfacing, and non-AtomicAttack ScenarioStep dispatch through process_async via subclass override. * Full unit suite: 7958 passed, 118 skipped, 1 pre-existing ODBC env failure unrelated to the refactor. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/core/scenario.py | 395 +++++++++++++---- .../scenario/test_scenario_graph_execution.py | 401 ++++++++++++++++++ 2 files changed, 714 insertions(+), 82 deletions(-) create mode 100644 tests/unit/scenario/test_scenario_graph_execution.py diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index b82f16af9..2897421a0 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -36,8 +36,14 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_scorer_target +from pyrit.scenario.core.strategy_graph import ( + PolicyAction, + StrategyGraph, + StrategyPolicy, +) from pyrit.score import ( Scorer, SelfAskRefusalScorer, @@ -238,6 +244,15 @@ def __init__( # before _get_atomic_attacks_async is awaited so overrides can read it. self._include_baseline: bool = False + # Phase 5: state-machine view over the scenario's steps. Built lazily in + # _execute_scenario_async from self._atomic_attacks after the resume filter + # has been applied. Stays None until the first execution attempt. + # The default ``_build_execution_graph`` uses ``int`` as the state type; we + # store as ``StrategyGraph[ScenarioStep, Any]`` so subclasses that override + # the builder with a string/Enum state type can stash their graph here too + # without invariance fights. + self._execution_graph: Optional[StrategyGraph[ScenarioStep, Any]] = None + # Deprecated constructor-time baseline override. Will be removed in 0.16.0, along # with the include_default_baseline kwarg above and the legacy fallback branch in # initialize_async. Subclass shims set this attribute directly to avoid double-warning. @@ -260,6 +275,35 @@ def atomic_attack_count(self) -> int: """Get the number of atomic attacks in this scenario.""" return len(self._atomic_attacks) + @property + def execution_graph(self) -> Optional[StrategyGraph[ScenarioStep, Any]]: + """ + The ``StrategyGraph`` driving this scenario's current execution attempt. + + Built in ``_execute_scenario_async`` from the steps that remain after the + resume filter; ``None`` before the first call to ``run_async`` (or any + time outside of an active execution attempt). + + Subclasses can override ``_build_execution_graph`` to declare a richer + state-machine policy; the default uses ``_build_default_linear_policy`` + to wrap ``self._atomic_attacks`` in a linear traversal that matches the + legacy flat ``_execute_scenario_async`` loop exactly. + """ + return self._execution_graph + + @property + def execution_history(self) -> list[ScenarioStepResult]: + """ + Ordered list of step results produced by the current execution attempt. + + Empty when no graph has been built yet, or between retry attempts that + reset the graph. Each entry is the ``ScenarioStepResult`` yielded by a + step's policy action, in execution order. + """ + if self._execution_graph is None: + return [] + return [result for _, result in self._execution_graph.history] + @classmethod @abstractmethod def get_strategy_class(cls) -> type[ScenarioStrategy]: @@ -1043,28 +1087,161 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: return atomic_attacks - async def run_async(self) -> ScenarioResult: + def _build_execution_graph( + self, *, steps: Optional[Sequence[ScenarioStep]] = None + ) -> StrategyGraph[ScenarioStep, int]: """ - Execute all atomic attacks in the scenario sequentially. + Build the ``StrategyGraph`` that drives this execution attempt. + + Default implementation wraps the supplied ``steps`` (or, if omitted, + ``self._atomic_attacks``) in a linear policy via + ``_build_default_linear_policy``. This produces a graph whose traversal + is identical to the legacy flat ``_execute_scenario_async`` loop, so + scenarios that haven't opted into a richer policy see no behavior + change. - Each AtomicAttack is executed in order, and all results are aggregated - into a ScenarioResult containing the scenario metadata and all attack results. - This method supports resumption - if the scenario raises an exception partway through, - calling run_async again will skip already-completed objectives. + Subclasses with a state-machine flavor (rapid-response, adaptive, + branching) override this to author their own ``StrategyPolicy`` and + pass it to ``StrategyGraph``. Such overrides should still consume + ``self._atomic_attacks`` as the seed of their step inventory so the + existing resume-by-name path keeps working through Phase 5. - If max_retries is set, the scenario will automatically retry after an exception up to - the specified number of times. Each retry will resume from where it left off, - skipping completed objectives. + Args: + steps (Optional[Sequence[ScenarioStep]]): Steps to drive. ``None`` + falls back to ``self._atomic_attacks``. ``_execute_scenario_async`` + passes the resume-filtered list explicitly so already-completed + steps are not re-executed. Returns: - ScenarioResult: Contains scenario identifier and aggregated list of all - attack results from all atomic attacks. + StrategyGraph[ScenarioStep, int]: The graph that ``run_async`` + will iterate. Raises: - ValueError: If the scenario has no atomic attacks configured. If your scenario - requires initialization, call await scenario.initialize() first. - ValueError: If the scenario raises an exception after exhausting all retry attempts. - RuntimeError: If the scenario fails for any other reason while executing. + ValueError: If ``steps`` is empty (or unset and there are no + atomic attacks). + """ + effective_steps = list(steps) if steps is not None else list(self._atomic_attacks) + if not effective_steps: + raise ValueError( + "Cannot build an execution graph with no steps. Either initialize the " + "scenario via ``await scenario.initialize_async(...)`` so atomic attacks are " + "populated, or override ``_build_execution_graph`` to supply your own steps." + ) + return StrategyGraph(policy=self._build_default_linear_policy(steps=effective_steps)) + + def _build_default_linear_policy( + self, *, steps: Sequence[ScenarioStep] + ) -> StrategyPolicy[ScenarioStep, int]: + """ + Build a linear-traversal policy that preserves scenario-level execution params. + + Each policy action runs ``steps[i]`` and transitions to state ``i + 1``; + state ``len(steps)`` is the sole terminal state. For ``AtomicAttack`` + steps the action calls ``run_async`` directly so ``max_concurrency`` and + ``return_partial_on_failure`` semantics that the legacy flat loop relied + on are preserved. Non-``AtomicAttack`` steps fall back to + ``process_async`` (so any future custom ``ScenarioStep`` subclass works + out of the box). In both paths the step's ``name`` is stamped into + ``ScenarioStepResult.metadata['step_name']`` so the orchestrator can + identify the step at yield time. + + Args: + steps (Sequence[ScenarioStep]): The steps to wrap. Must be non-empty. + + Returns: + StrategyPolicy[ScenarioStep, int]: A frozen linear policy. + + Raises: + ValueError: If ``steps`` is empty. + """ + if not steps: + raise ValueError("_build_default_linear_policy requires at least one step.") + + max_concurrency = self._max_concurrency + terminal_state = len(steps) + actions: dict[int, PolicyAction[ScenarioStep, int]] = {} + + for index, step in enumerate(steps): + + async def _action( + graph: StrategyGraph[ScenarioStep, int], + _step: ScenarioStep = step, + _next: int = index + 1, + _max_concurrency: int = max_concurrency, + ) -> tuple[int, ScenarioStepResult | None]: + graph.bind_current_step(step=_step) + try: + if isinstance(_step, AtomicAttack): + executor_result = await _step.run_async( + max_concurrency=_max_concurrency, + return_partial_on_failure=True, + ) + result: ScenarioStepResult | None = ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "step_name": _step.atomic_attack_name, + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + else: + base_result = await _step.process_async() + # Re-stamp metadata with step_name so the orchestrator can route results + # without depending on graph.current_step (which is cleared before yield). + merged_metadata = {"step_name": _step.name, **base_result.metadata} + result = ScenarioStepResult( + outcome=base_result.outcome, + attack_results=base_result.attack_results, + step_identifier=base_result.step_identifier, + metadata=merged_metadata, + ) + finally: + graph.bind_current_step(step=None) + return _next, result + + actions[index] = _action + + return StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({terminal_state}), + ) + + async def run_async(self) -> ScenarioResult: + """ + Execute the scenario by walking its ``StrategyGraph``. + + Each ``ScenarioStep`` produces a ``ScenarioStepResult`` whose attack + results are persisted in order and tagged with a ``step_identifier`` + so step-level filtering and grouping work alongside the existing + ``atomic_attack_identifier`` lineage. The default execution graph + produced by ``_build_execution_graph`` is a linear traversal of + ``self._atomic_attacks``, so scenarios that have not opted into a + richer policy see the same end-to-end behavior as before. + + The graph is rebuilt at the start of every execution attempt from the + resume-filtered step list, so calling ``run_async`` after a partial + failure skips already-completed work the same way the legacy flat + loop did. ``self.execution_graph`` and ``self.execution_history`` + expose the current attempt's state. + + If ``max_retries`` is set, the scenario will automatically retry after + an exception up to the specified number of times. Each retry rebuilds + the graph from the current remaining steps. + + Returns: + ScenarioResult: Contains scenario identifier and aggregated list of + attack results from every step that ran. + + Raises: + ValueError: If the scenario has no atomic attacks configured. If your + scenario requires initialization, call + ``await scenario.initialize_async()`` first. + ValueError: If the scenario raises an exception after exhausting all + retry attempts. + RuntimeError: If the scenario fails for any other reason while + executing. Example: >>> result = await scenario.run_async() @@ -1123,18 +1300,28 @@ async def _execute_scenario_async(self) -> ScenarioResult: """ Perform a single execution attempt of the scenario. - This method contains the core execution logic and can be called multiple times - for retry attempts. It increments the try counter, executes remaining atomic attacks, - and returns the scenario result. + Iterates ``self.execution_graph.event_loop_async()`` and applies the + same per-step persistence, partial-failure handling, and retry + semantics that the legacy flat loop applied per-``AtomicAttack``. The + graph is built once per execution attempt from the resume-filtered + ``self._atomic_attacks`` so already-completed steps are skipped. Returns: ScenarioResult: The result of this execution attempt. Raises: - Exception: Any exception that occurs during scenario execution. - ValueError: If a lookup for a scenario for a given ID fails. - ValueError: If atomic attack execution fails. + ValueError: If ``self._scenario_result_id`` is missing or any + step partially fails. + Exception: Any exception raised while executing a step is logged, + the scenario is marked ``FAILED``, and the exception is re-raised. """ + # Lazy import to avoid module-level circularity: build_step_identifier + # lives in pyrit.identifiers which itself imports several pyrit.models + # types that the scenario module re-exports indirectly. + from pyrit.identifiers.evaluation_identifier import StepEvaluationIdentifier + from pyrit.identifiers.step_identifier import build_step_identifier + from pyrit.memory.memory_models import MAX_IDENTIFIER_VALUE_LENGTH + logger.info(f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks") # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point @@ -1177,105 +1364,147 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Mark scenario as in progress self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="IN_PROGRESS") - # Calculate starting index based on completed attacks - completed_count = len(self._atomic_attacks) - len(remaining_attacks) + # Build a fresh execution graph from the resume-filtered steps for this attempt. + # We always rebuild on retry so the policy reflects the currently-pending work. + self._execution_graph = self._build_execution_graph(steps=remaining_attacks) + + # Calculate starting index based on completed attacks (for progress bar continuity). + total_steps = len(self._atomic_attacks) + completed_count = total_steps - len(remaining_attacks) + progress = tqdm( + desc=f"Executing {self._name}", + unit="attack", + total=total_steps, + initial=completed_count, + ) + step_position = completed_count + # Track the most recent step we attempted so a step-raised exception + # can still log the offending step's name. ``graph.current_step`` is + # cleared in the policy action's ``finally`` before the exception + # propagates, so it's not a reliable post-mortem source. + last_attempted_step_name: str = "" try: - for i, atomic_attack in enumerate( - tqdm( - remaining_attacks, - desc=f"Executing {self._name}", - unit="attack", - total=len(self._atomic_attacks), - initial=completed_count, - ), - start=completed_count + 1, - ): - logger.info( - f"Executing atomic attack {i}/{len(self._atomic_attacks)} " - f"('{atomic_attack.atomic_attack_name}') in scenario '{self._name}'" - ) + try: + async for step_result in self._execution_graph.event_loop_async(): + step_position += 1 + step_name = step_result.metadata.get("step_name", "") + last_attempted_step_name = step_name - try: - atomic_results = await atomic_attack.run_async( - max_concurrency=self._max_concurrency, - return_partial_on_failure=True, + logger.info( + f"Executing atomic attack {step_position}/{total_steps} " + f"('{step_name}') in scenario '{self._name}'" ) - # Always save completed results, even if some objectives didn't complete - if atomic_results.completed_results: + # Stamp step_identifier on every attack_result that doesn't already carry one. + # Steps may opt into setting it themselves (e.g., adaptive scenarios with + # nested attack executions); otherwise the default linear path stamps a + # one-attack-per-step composite identifier here. We mirror + # ``AtomicAttack._enrich_atomic_attack_identifiers``: populate the eval_hash + # before truncation so it survives the DB round-trip, then push the enriched + # identifier back to the AttackResultEntry row by attack_result_id. + for attack_result in step_result.attack_results: + if ( + attack_result.step_identifier is None + and attack_result.atomic_attack_identifier is not None + ): + new_identifier = build_step_identifier( + step_name=step_name, + outcome=step_result.outcome, + attack_execution_identifiers=[attack_result.atomic_attack_identifier], + ) + if new_identifier.eval_hash is None: + new_identifier = new_identifier.with_eval_hash( + StepEvaluationIdentifier(new_identifier).eval_hash + ) + attack_result.step_identifier = new_identifier + + # Push the (newly-stamped or pre-stamped) step_identifier to the existing + # AttackResultEntry so downstream ``get_scenario_results`` rehydrates it. + if attack_result.step_identifier is not None and attack_result.attack_result_id: + self._memory.update_attack_result_by_id( + attack_result_id=attack_result.attack_result_id, + update_fields={ + "step_identifier": attack_result.step_identifier.to_dict( + max_value_length=MAX_IDENTIFIER_VALUE_LENGTH, + ), + }, + ) + + # Always save completed results, even if some objectives didn't complete. + if step_result.attack_results: await self._update_scenario_result_async( - atomic_attack_name=atomic_attack.atomic_attack_name, - attack_results=atomic_results.completed_results, + atomic_attack_name=step_name, + attack_results=step_result.attack_results, ) - # Check if there were any incomplete objectives - if atomic_results.has_incomplete: - incomplete_count = len(atomic_results.incomplete_objectives) - completed_count = len(atomic_results.completed_results) + # Partial-failure handling. Only the AtomicAttack adapter path stuffs + # ``incomplete_objectives`` into metadata today; custom ScenarioStep + # subclasses opt in by populating the same key, so the same FAILED-state + # path covers any future step that wants partial-failure semantics. + incomplete_objectives = step_result.metadata.get("incomplete_objectives") or [] + if incomplete_objectives: + incomplete_count = len(incomplete_objectives) + completed_in_step = len(step_result.attack_results) logger.error( - f"Atomic attack {i}/{len(self._atomic_attacks)} " - f"('{atomic_attack.atomic_attack_name}') partially completed: " - f"{completed_count} completed, {incomplete_count} incomplete" + f"Atomic attack {step_position}/{total_steps} " + f"('{step_name}') partially completed: " + f"{completed_in_step} completed, {incomplete_count} incomplete" ) - # Log details of each incomplete objective - for obj, exc in atomic_results.incomplete_objectives: + for obj, exc in incomplete_objectives: logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") - # Collect error attack result IDs from the exceptions error_ids = [] - for _, exc in atomic_results.incomplete_objectives: + for _, exc in incomplete_objectives: error_id = getattr(exc, "error_attack_result_id", None) if error_id: error_ids.append(error_id) - # Link error attack results to the scenario result if error_ids: self._memory.update_scenario_error_attacks( scenario_result_id=scenario_result_id, error_attack_result_ids=error_ids, ) - # Mark scenario as failed error_msg = ( - f"Atomic attack '{atomic_attack.atomic_attack_name}' partially failed: " - f"{incomplete_count} of {incomplete_count + completed_count} objectives incomplete. " - f"See attack results for details." + f"Atomic attack '{step_name}' partially failed: " + f"{incomplete_count} of {incomplete_count + completed_in_step} " + f"objectives incomplete. See attack results for details." ) self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", error_message=error_msg, - error_type=type(atomic_results.incomplete_objectives[0][1]).__name__, + error_type=type(incomplete_objectives[0][1]).__name__, ) - # Raise exception with detailed information - raise ValueError(error_msg) from atomic_results.incomplete_objectives[0][1] + raise ValueError(error_msg) from incomplete_objectives[0][1] + logger.info( - f"Atomic attack {i}/{len(self._atomic_attacks)} completed successfully with " - f"{len(atomic_results.completed_results)} results" + f"Atomic attack {step_position}/{total_steps} completed successfully with " + f"{len(step_result.attack_results)} results" ) + progress.update(1) - except Exception as e: - # Exception was raised either by run_async or by our check above - logger.error( - f"Atomic attack {i}/{len(self._atomic_attacks)} " - f"('{atomic_attack.atomic_attack_name}') failed in scenario '{self._name}': {str(e)}" - ) + except Exception as e: + logger.error( + f"Atomic attack {step_position}/{total_steps} " + f"('{last_attempted_step_name}') failed in scenario '{self._name}': {str(e)}" + ) - # Mark scenario as failed if not already done - scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) - if scenario_results and scenario_results[0].scenario_run_state != "FAILED": - self._memory.update_scenario_run_state( - scenario_result_id=scenario_result_id, - scenario_run_state="FAILED", - error_message=str(e), - error_type=type(e).__name__, - ) + # Mark scenario as failed if not already done + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if scenario_results and scenario_results[0].scenario_run_state != "FAILED": + self._memory.update_scenario_run_state( + scenario_result_id=scenario_result_id, + scenario_run_state="FAILED", + error_message=str(e), + error_type=type(e).__name__, + ) - raise + raise logger.info(f"Scenario '{self._name}' completed successfully") @@ -1294,3 +1523,5 @@ async def _execute_scenario_async(self) -> ScenarioResult: except Exception as e: logger.error(f"Scenario '{self._name}' failed with error: {str(e)}") raise + finally: + progress.close() diff --git a/tests/unit/scenario/test_scenario_graph_execution.py b/tests/unit/scenario/test_scenario_graph_execution.py new file mode 100644 index 000000000..bd34e9c46 --- /dev/null +++ b/tests/unit/scenario/test_scenario_graph_execution.py @@ -0,0 +1,401 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Phase 5 — coverage for ``Scenario`` driving execution through ``StrategyGraph``. + +These tests pin the new public surface (``execution_graph``, ``execution_history``, +``_build_execution_graph``, ``_build_default_linear_policy``) plus the contract +that the default linear policy produces the same end-to-end behavior as the +legacy flat loop (max_concurrency propagation, step_identifier stamping, +custom-policy overrides, non-AtomicAttack ``ScenarioStep`` dispatch). +""" + +from typing import ClassVar, cast +from unittest.mock import MagicMock, PropertyMock + +import pytest + +from pyrit.executor.attack.core import AttackExecutorResult +from pyrit.identifiers import ComponentIdentifier +from pyrit.memory import CentralMemory +from pyrit.models import AttackOutcome, AttackResult +from pyrit.scenario import DatasetConfiguration, ScenarioResult +from pyrit.scenario.core import AtomicAttack, BaselinePolicy, Scenario, ScenarioStrategy +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.strategy_graph import ( + PolicyAction, + StrategyGraph, + StrategyPolicy, +) +from pyrit.score import Scorer + +_TEST_SCORER_ID = ComponentIdentifier( + class_name="MockScorer", + class_module="tests.unit.scenarios", +) + + +def _save_results_to_memory(attack_results): + memory = CentralMemory.get_memory_instance() + memory.add_attack_results_to_memory(attack_results=attack_results) + + +def _make_atomic_attack_mock(name: str, attack_result: AttackResult) -> MagicMock: + """Build a fake AtomicAttack whose run_async returns the supplied result.""" + mock_attack = MagicMock() + mock_attack.get_objective_target.return_value = MagicMock() + mock_attack.get_attack_scoring_config.return_value = MagicMock() + + attack = MagicMock(spec=AtomicAttack) + attack.atomic_attack_name = name + attack.display_group = name + attack._attack = mock_attack + type(attack).objectives = PropertyMock(return_value=[attack_result.objective]) + + async def _fake_run(*args, **kwargs): + _save_results_to_memory([attack_result]) + return AttackExecutorResult(completed_results=[attack_result], incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_fake_run) + return attack + + +def _sample_result(index: int) -> AttackResult: + result = AttackResult( + conversation_id=f"conv-{index}", + objective=f"objective-{index}", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + # Stamp a real atomic_attack_identifier so step_identifier stamping can wrap it. + result.atomic_attack_identifier = ComponentIdentifier( + class_name="MockAttack", + class_module="tests.unit.scenarios", + params={"name": f"attack-{index}"}, + ) + return result + + +class _GraphConcreteScenario(Scenario): + """Concrete Scenario for graph-execution tests. + + Mirrors ``ConcreteScenario`` from test_scenario.py but stays local so we can + swap the execution-graph builder per test without coupling to the broader + test_scenario fixture. + """ + + BASELINE_POLICY: ClassVar[BaselinePolicy] = BaselinePolicy.Forbidden + + def __init__(self, atomic_attacks_to_return=None, **kwargs): + class _TestStrategy(ScenarioStrategy): + TEST = ("test", {"concrete"}) + ALL = ("all", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} + + kwargs.setdefault("strategy_class", _TestStrategy) + + if "objective_scorer" not in kwargs: + mock_scorer = MagicMock(spec=Scorer) + mock_scorer.get_identifier.return_value = _TEST_SCORER_ID + mock_scorer.get_scorer_metrics.return_value = None + kwargs["objective_scorer"] = mock_scorer + + super().__init__(**kwargs) + self._atomic_attacks_to_return = atomic_attacks_to_return or [] + + @classmethod + def get_strategy_class(cls): + class _TestStrategy(ScenarioStrategy): + TEST = ("test", {"concrete"}) + ALL = ("all", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} + + return _TestStrategy + + @classmethod + def get_default_strategy(cls): + return cls.get_strategy_class().ALL + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + return DatasetConfiguration() + + async def _get_atomic_attacks_async(self): + return self._atomic_attacks_to_return + + +@pytest.fixture +def mock_objective_target(): + target = MagicMock() + target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", + class_module="test", + ) + return target + + +@pytest.mark.usefixtures("patch_central_database") +class TestBuildExecutionGraph: + """Pin the default ``_build_execution_graph`` factory contract.""" + + def test_raises_when_no_steps_and_no_atomic_attacks(self, mock_objective_target): + scenario = _GraphConcreteScenario(name="Empty", version=1) + with pytest.raises(ValueError, match="no steps"): + scenario._build_execution_graph() + + def test_raises_when_explicit_steps_empty(self, mock_objective_target): + scenario = _GraphConcreteScenario(name="Empty", version=1) + with pytest.raises(ValueError, match="no steps"): + scenario._build_execution_graph(steps=[]) + + async def test_default_graph_terminates_at_len_steps(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(3)] + scenario = _GraphConcreteScenario( + name="Default", version=1, atomic_attacks_to_return=attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + graph = scenario._build_execution_graph() + + assert graph.policy.initial_state == 0 + assert graph.policy.terminal_states == frozenset({3}) + # Three non-terminal states must have actions. + for state in range(3): + assert callable(graph.policy.get_action(state=state)) + + async def test_explicit_steps_override_atomic_attacks(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(3)] + scenario = _GraphConcreteScenario( + name="Default", version=1, atomic_attacks_to_return=attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + # Build with only the first two — terminal moves down to 2. + graph = scenario._build_execution_graph(steps=attacks[:2]) + assert graph.policy.terminal_states == frozenset({2}) + + +@pytest.mark.usefixtures("patch_central_database") +class TestExecutionGraphPropertyAndHistory: + """``execution_graph`` and ``execution_history`` reflect the active attempt.""" + + def test_graph_and_history_are_empty_before_run(self): + scenario = _GraphConcreteScenario(name="Pre-run", version=1) + assert scenario.execution_graph is None + assert scenario.execution_history == [] + + async def test_run_async_populates_graph_and_history(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] + scenario = _GraphConcreteScenario( + name="Populated", version=1, atomic_attacks_to_return=attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + await scenario.run_async() + + assert scenario.execution_graph is not None + # Two steps executed; history records both. + assert len(scenario.execution_history) == 2 + names = [r.metadata.get("step_name") for r in scenario.execution_history] + assert names == ["a0", "a1"] + + +@pytest.mark.usefixtures("patch_central_database") +class TestStepIdentifierStamping: + """Default linear path stamps a step_identifier on every persisted AttackResult.""" + + async def test_each_result_has_step_identifier(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] + scenario = _GraphConcreteScenario( + name="Stamping", version=1, atomic_attacks_to_return=attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + result = await scenario.run_async() + + assert isinstance(result, ScenarioResult) + flat_results = [ar for arr in result.attack_results.values() for ar in arr] + assert flat_results, "expected at least one persisted result" + for ar in flat_results: + assert ar.step_identifier is not None + assert ar.step_identifier.class_name == "ScenarioStep" + # The step name shows up in the params (inlined by ComponentIdentifier.to_dict). + id_dict = ar.step_identifier.to_dict() + assert id_dict["step_name"] in {"a0", "a1"} + assert id_dict["outcome"] == "done" + + async def test_pre_stamped_step_identifier_is_preserved(self, mock_objective_target): + """If a step pre-stamps its own step_identifier, the orchestrator must not overwrite.""" + result_obj = _sample_result(0) + pre_stamped = ComponentIdentifier( + class_name="ScenarioStep", + class_module="custom", + params={"step_name": "custom_step", "outcome": "custom_outcome", "eval_version": 99}, + ) + result_obj.step_identifier = pre_stamped + + attack = _make_atomic_attack_mock("a0", result_obj) + + async def _run_returning_stamped(*args, **kwargs): + _save_results_to_memory([result_obj]) + return AttackExecutorResult(completed_results=[result_obj], incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_run_returning_stamped) + + scenario = _GraphConcreteScenario( + name="Pre-stamped", version=1, atomic_attacks_to_return=[attack], + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + await scenario.run_async() + + stamped = scenario.execution_history[0].attack_results[0].step_identifier + # The orchestrator must not have replaced the pre-stamped identifier with a + # default ScenarioStep-shape one. We compare by params since equality is structural. + assert stamped is not None + assert stamped.params.get("step_name") == "custom_step" + assert stamped.params.get("outcome") == "custom_outcome" + assert stamped.params.get("eval_version") == 99 + + +@pytest.mark.usefixtures("patch_central_database") +class TestMaxConcurrencyPropagation: + """``max_concurrency`` flows from the scenario through the default linear policy.""" + + async def test_atomic_attack_receives_scenario_max_concurrency(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] + scenario = _GraphConcreteScenario( + name="Concurrency", version=1, atomic_attacks_to_return=attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=7) + + await scenario.run_async() + + for attack in attacks: + attack.run_async.assert_called_once_with(max_concurrency=7, return_partial_on_failure=True) + + +@pytest.mark.usefixtures("patch_central_database") +class TestPartialFailureSurfacing: + """Partial-failure metadata propagates through the step result and raises ValueError.""" + + async def test_incomplete_objectives_in_metadata_raise(self, mock_objective_target): + result = _sample_result(0) + attack = _make_atomic_attack_mock("a0", result) + + async def _run_partial(*args, **kwargs): + _save_results_to_memory([result]) + return AttackExecutorResult( + completed_results=[result], + incomplete_objectives=[("partial-obj", RuntimeError("boom"))], + ) + + attack.run_async = MagicMock(side_effect=_run_partial) + + scenario = _GraphConcreteScenario( + name="Partial", version=1, atomic_attacks_to_return=[attack], + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + with pytest.raises(ValueError, match="partially failed"): + await scenario.run_async() + + +class _CountingStep(ScenarioStep): + """Non-AtomicAttack step that records every process_async call.""" + + def __init__(self, *, name: str) -> None: + self.name = name + self.outputs = ["done"] + self.call_count = 0 + + async def process_async(self) -> ScenarioStepResult: + self.call_count += 1 + return ScenarioStepResult(outcome="done", attack_results=[]) + + +class _CustomStepScenario(_GraphConcreteScenario): + """Scenario that drives non-AtomicAttack steps via an explicit StrategyGraph. + + Overrides ``_get_remaining_atomic_attacks_async`` (which is shaped for + ``AtomicAttack`` instances) and ``_build_execution_graph`` to swap in a + policy over arbitrary ``ScenarioStep`` subclasses without touching the + legacy resume path. + """ + + def __init__(self, *, steps: list[_CountingStep], **kwargs): + super().__init__(**kwargs) + self._custom_steps = steps + + async def _get_remaining_atomic_attacks_async(self): # type: ignore[override] + # The custom steps are not AtomicAttacks; bypass the legacy resume filter + # by returning them as-is. The orchestrator forwards them to our overridden + # ``_build_execution_graph``, which wraps them in a hand-rolled policy. + return self._custom_steps + + def _build_execution_graph(self, *, steps=None): + effective_steps = list(steps) if steps is not None else list(self._custom_steps) + + actions: dict[int, PolicyAction[ScenarioStep, int]] = {} + for index, step in enumerate(effective_steps): + + async def _action( + graph: StrategyGraph[ScenarioStep, int], + _step: ScenarioStep = step, + _next: int = index + 1, + ) -> tuple[int, ScenarioStepResult | None]: + graph.bind_current_step(step=_step) + try: + base = await _step.process_async() + merged = {"step_name": _step.name, **base.metadata} + return _next, ScenarioStepResult( + outcome=base.outcome, + attack_results=base.attack_results, + step_identifier=base.step_identifier, + metadata=merged, + ) + finally: + graph.bind_current_step(step=None) + + actions[index] = _action + + policy = StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({len(effective_steps)}), + ) + return StrategyGraph(policy=policy) + + +@pytest.mark.usefixtures("patch_central_database") +class TestNonAtomicAttackStepDispatch: + """Non-AtomicAttack ``ScenarioStep`` subclasses route through ``process_async``.""" + + async def test_custom_step_routed_through_process_async(self, mock_objective_target): + step_a = _CountingStep(name="custom_a") + step_b = _CountingStep(name="custom_b") + + scenario = _CustomStepScenario( + steps=[step_a, step_b], name="Custom-steps", version=1, + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + # Mirror what initialize_async would do for atomic attacks so the + # orchestrator's progress-bar math sees the right total. The cast satisfies + # the field's nominal ``list[AtomicAttack]`` annotation even though our + # steps are non-atomic ``ScenarioStep`` subclasses. + scenario._atomic_attacks = cast("list", [step_a, step_b]) + + await scenario.run_async() + + assert step_a.call_count == 1 + assert step_b.call_count == 1 + history_names = [r.metadata.get("step_name") for r in scenario.execution_history] + assert history_names == ["custom_a", "custom_b"] From 99fa9dce9e0380b95cd2fdcc81eb2adeff5267f2 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:16:35 -0700 Subject: [PATCH 06/42] MAINT: vendor adaptive scenario from PR #1760 verbatim (Phase 6a) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 6a brings the in-flight adaptive scenario landing (PR #1760, hawestra/text_adaptive_scenario) into this branch as a sibling module so Phase 6b can migrate it onto the new StrategyGraph without blocking on upstream merge order. Files vendored verbatim from the PR head (137597426533f8d7a73d9b34ac1960e321bb928a): * pyrit/scenario/scenarios/adaptive/{__init__.py, adaptive_scenario.py, dispatcher.py, selector.py, text_adaptive.py} * tests/unit/scenario/scenarios/adaptive/{test_dispatcher.py, test_selector.py, test_text_adaptive.py} * doc/code/scenarios/3_adaptive_scenarios.{ipynb, py} * doc/myst.yml — added 3_adaptive_scenarios entry Only edit applied locally: * pyrit/scenario/__init__.py — merged the PR's adaptive export with this branch's existing Phase 0-3 scaffold exports (PolicyAction, StrategyGraph, StrategyPolicy, ScenarioStep, ScenarioStepResult, ScenarioCoreState, ScenarioStateLike, linear_strategy_policy). Re-sorted the __all__ block to keep submodule names grouped. Test counts: vendored adaptive suite runs 63 tests green; full unit suite 8021 passed / 118 skipped / 1 pre-existing ODBC env failure (test_main_prints_startup_message, unrelated). Phase 6b will rewrite AdaptiveScenario to drive its event loop through StrategyGraph + a recurring SELECTING state, deprecating AdaptiveDispatchAttack in favor of an AdaptiveStep whose process_async owns one selector tick. The vendored tests become the regression net for that port. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- doc/code/scenarios/3_adaptive_scenarios.ipynb | 242 +++++++++ doc/code/scenarios/3_adaptive_scenarios.py | 175 ++++++ doc/myst.yml | 1 + pyrit/scenario/__init__.py | 8 +- pyrit/scenario/scenarios/adaptive/__init__.py | 28 + .../scenarios/adaptive/adaptive_scenario.py | 327 +++++++++++ .../scenario/scenarios/adaptive/dispatcher.py | 278 ++++++++++ pyrit/scenario/scenarios/adaptive/selector.py | 182 +++++++ .../scenarios/adaptive/text_adaptive.py | 137 +++++ .../scenarios/adaptive/test_dispatcher.py | 303 +++++++++++ .../scenarios/adaptive/test_selector.py | 225 ++++++++ .../scenarios/adaptive/test_text_adaptive.py | 512 ++++++++++++++++++ 12 files changed, 2416 insertions(+), 2 deletions(-) create mode 100644 doc/code/scenarios/3_adaptive_scenarios.ipynb create mode 100644 doc/code/scenarios/3_adaptive_scenarios.py create mode 100644 pyrit/scenario/scenarios/adaptive/__init__.py create mode 100644 pyrit/scenario/scenarios/adaptive/adaptive_scenario.py create mode 100644 pyrit/scenario/scenarios/adaptive/dispatcher.py create mode 100644 pyrit/scenario/scenarios/adaptive/selector.py create mode 100644 pyrit/scenario/scenarios/adaptive/text_adaptive.py create mode 100644 tests/unit/scenario/scenarios/adaptive/test_dispatcher.py create mode 100644 tests/unit/scenario/scenarios/adaptive/test_selector.py create mode 100644 tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py diff --git a/doc/code/scenarios/3_adaptive_scenarios.ipynb b/doc/code/scenarios/3_adaptive_scenarios.ipynb new file mode 100644 index 000000000..7be2b738e --- /dev/null +++ b/doc/code/scenarios/3_adaptive_scenarios.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Adaptive Scenarios\n", + "\n", + "An **adaptive scenario** doesn't run every attack technique against every objective.\n", + "Instead, it picks which technique to try next per-objective, learns from what worked,\n", + "and stops as soon as one technique succeeds. This concentrates spend on techniques\n", + "that actually work on your target.\n", + "\n", + "## How it works (high level)\n", + "\n", + "For each objective, the scenario tries up to `max_attempts_per_objective` techniques:\n", + "\n", + "- With probability `epsilon`, it **explores** — picks a random technique.\n", + "- Otherwise it **exploits** — picks the technique with the highest observed success\n", + " rate so far.\n", + "- It records the outcome and stops early on success.\n", + "\n", + "Unseen techniques are tried first, so the first few objectives effectively round-robin\n", + "through every technique before the scenario settles on the best performers.\n", + "\n", + "## Adaptive vs. static scenarios\n", + "\n", + "| Feature | Static scenarios | Adaptive scenarios |\n", + "|---------------------|-----------------------------------|------------------------------------|\n", + "| Technique selection | Run every selected technique | Pick per-objective from outcomes |\n", + "| Early stopping | No | Yes — stops on first success |\n", + "| Cost | O(techniques × objectives) | O(max_attempts × objectives) |\n", + "\n", + "`AdaptiveScenario` is the modality-agnostic base class.\n", + "[`TextAdaptive`](../../../pyrit/scenario/scenarios/adaptive/text_adaptive.py) is the\n", + "text subclass used in the examples below." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "from pyrit.registry import TargetRegistry\n", + "from pyrit.scenario import DatasetConfiguration\n", + "from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter\n", + "from pyrit.scenario.scenarios.adaptive import TextAdaptive, harm_category_context\n", + "from pyrit.setup import initialize_from_config_async\n", + "\n", + "await initialize_from_config_async(config_path=Path(\"../../scanner/pyrit_conf.yaml\")) # type: ignore\n", + "\n", + "objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name(\"openai_chat\")\n", + "printer = ConsoleScenarioResultPrinter()" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Basic usage\n", + "\n", + "Defaults: `epsilon=0.2`, `max_attempts_per_objective=3`, the subclass's default datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "scenario = TextAdaptive()\n", + "\n", + "await scenario.initialize_async( # type: ignore\n", + " objective_target=objective_target,\n", + ")\n", + "result = await scenario.run_async() # type: ignore\n", + "await printer.write_async(result) # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Configuring a run\n", + "\n", + "All the knobs below are constructor or `initialize_async` arguments — combine whichever\n", + "you need on a single scenario instance:\n", + "\n", + "- **`epsilon`** — exploration probability. `0.0` is pure exploit, `1.0` is pure random,\n", + " `0.2` (default) is 20% exploration.\n", + "- **`max_attempts_per_objective`** — caps techniques tried per objective. Higher means\n", + " more chances to succeed and more API calls.\n", + "- **`context_extractor`** — partitions the success-rate table. The default\n", + " `global_context` keeps one shared table; `harm_category_context` learns each harm\n", + " category independently. Custom callables of type `Callable[[SeedAttackGroup], str]`\n", + " are supported.\n", + "- **`seed`** — makes every selection decision deterministic.\n", + "- **`scenario_strategies`** (on `initialize_async`) — restricts which techniques the\n", + " selector can pick from. Use `TextAdaptive.get_strategy_class()` to access the enum.\n", + "\n", + "The cell below exercises all of them at once." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "strategy_class = TextAdaptive.get_strategy_class()\n", + "\n", + "configured_scenario = TextAdaptive(\n", + " epsilon=0.3,\n", + " max_attempts_per_objective=5,\n", + " context_extractor=harm_category_context,\n", + " seed=42,\n", + ")\n", + "\n", + "await configured_scenario.initialize_async( # type: ignore\n", + " objective_target=objective_target,\n", + " scenario_strategies=[strategy_class(\"single_turn\")],\n", + " dataset_config=DatasetConfiguration(\n", + " dataset_names=[\"airt_hate\", \"airt_violence\"],\n", + " max_dataset_size=4,\n", + " ),\n", + ")\n", + "configured_result = await configured_scenario.run_async() # type: ignore\n", + "await printer.write_async(configured_result) # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Resuming a run\n", + "\n", + "Adaptive scenarios are resumable — pass `scenario_result_id=...` to the `TextAdaptive`\n", + "constructor and the run picks up where it left off, with prior outcomes replayed into\n", + "the selector. Resume must use the same configuration as the original run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "resumed_scenario = TextAdaptive(\n", + " epsilon=0.3,\n", + " max_attempts_per_objective=5,\n", + " context_extractor=harm_category_context,\n", + " seed=42,\n", + " scenario_result_id=str(configured_result.id),\n", + ")\n", + "\n", + "await resumed_scenario.initialize_async( # type: ignore\n", + " objective_target=objective_target,\n", + " scenario_strategies=[strategy_class(\"single_turn\")],\n", + " dataset_config=DatasetConfiguration(\n", + " dataset_names=[\"airt_hate\", \"airt_violence\"],\n", + " max_dataset_size=4,\n", + " ),\n", + ")\n", + "resumed_result = await resumed_scenario.run_async() # type: ignore\n", + "await printer.write_async(resumed_result) # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Inspecting which techniques were tried\n", + "\n", + "The dispatcher stamps every objective's `AttackResult.metadata` with:\n", + "\n", + "- `adaptive_context` — the bucket key from the `context_extractor`.\n", + "- `adaptive_attempts` — the ordered list of `{\"technique\", \"outcome\"}` dicts\n", + " recording exactly which techniques the selector picked and what happened.\n", + "\n", + "Walk that metadata to see the per-objective trail and aggregate counts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import Counter\n", + "\n", + "# Per-objective trail\n", + "for results in resumed_result.attack_results.values():\n", + " for r in results:\n", + " attempts = r.metadata.get(\"adaptive_attempts\", [])\n", + " trail = \" → \".join(f\"{a['technique']}({a['outcome']})\" for a in attempts)\n", + " print(f\"[{r.outcome.value:7s}] {r.objective!r}: {trail}\")\n", + "\n", + "# Aggregate per-technique pick counts and success rate across the run\n", + "picks: Counter[str] = Counter()\n", + "wins: Counter[str] = Counter()\n", + "for results in resumed_result.attack_results.values():\n", + " for r in results:\n", + " for step in r.metadata.get(\"adaptive_attempts\", []):\n", + " picks[step[\"technique\"]] += 1\n", + " if step[\"outcome\"] == \"success\":\n", + " wins[step[\"technique\"]] += 1\n", + "\n", + "print(\"\\nTechnique wins / picks rate\")\n", + "for technique, n in picks.most_common():\n", + " print(f\"{technique:20s} {wins[technique]:>4} / {n:<4} {wins[technique] / n:.0%}\")" + ] + } + ], + "metadata": { + "jupytext": { + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/scenarios/3_adaptive_scenarios.py b/doc/code/scenarios/3_adaptive_scenarios.py new file mode 100644 index 000000000..96e3320bb --- /dev/null +++ b/doc/code/scenarios/3_adaptive_scenarios.py @@ -0,0 +1,175 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# --- + +# %% [markdown] +# # Adaptive Scenarios +# +# An **adaptive scenario** doesn't run every attack technique against every objective. +# Instead, it picks which technique to try next per-objective, learns from what worked, +# and stops as soon as one technique succeeds. This concentrates spend on techniques +# that actually work on your target. +# +# ## How it works (high level) +# +# For each objective, the scenario tries up to `max_attempts_per_objective` techniques: +# +# - With probability `epsilon`, it **explores** — picks a random technique. +# - Otherwise it **exploits** — picks the technique with the highest observed success +# rate so far. +# - It records the outcome and stops early on success. +# +# Unseen techniques are tried first, so the first few objectives effectively round-robin +# through every technique before the scenario settles on the best performers. +# +# ## Adaptive vs. static scenarios +# +# | Feature | Static scenarios | Adaptive scenarios | +# |---------------------|-----------------------------------|------------------------------------| +# | Technique selection | Run every selected technique | Pick per-objective from outcomes | +# | Early stopping | No | Yes — stops on first success | +# | Cost | O(techniques × objectives) | O(max_attempts × objectives) | +# +# `AdaptiveScenario` is the modality-agnostic base class. +# [`TextAdaptive`](../../../pyrit/scenario/scenarios/adaptive/text_adaptive.py) is the +# text subclass used in the examples below. + +# %% [markdown] +# ## Setup + +# %% +from pathlib import Path + +from pyrit.registry import TargetRegistry +from pyrit.scenario import DatasetConfiguration +from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter +from pyrit.scenario.scenarios.adaptive import TextAdaptive, harm_category_context +from pyrit.setup import initialize_from_config_async + +await initialize_from_config_async(config_path=Path("../../scanner/pyrit_conf.yaml")) # type: ignore + +objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name("openai_chat") +printer = ConsoleScenarioResultPrinter() + +# %% [markdown] +# ## Basic usage +# +# Defaults: `epsilon=0.2`, `max_attempts_per_objective=3`, the subclass's default datasets. + +# %% +scenario = TextAdaptive() + +await scenario.initialize_async( # type: ignore + objective_target=objective_target, +) +result = await scenario.run_async() # type: ignore +await printer.write_async(result) # type: ignore + +# %% [markdown] +# ## Configuring a run +# +# All the knobs below are constructor or `initialize_async` arguments — combine whichever +# you need on a single scenario instance: +# +# - **`epsilon`** — exploration probability. `0.0` is pure exploit, `1.0` is pure random, +# `0.2` (default) is 20% exploration. +# - **`max_attempts_per_objective`** — caps techniques tried per objective. Higher means +# more chances to succeed and more API calls. +# - **`context_extractor`** — partitions the success-rate table. The default +# `global_context` keeps one shared table; `harm_category_context` learns each harm +# category independently. Custom callables of type `Callable[[SeedAttackGroup], str]` +# are supported. +# - **`seed`** — makes every selection decision deterministic. +# - **`scenario_strategies`** (on `initialize_async`) — restricts which techniques the +# selector can pick from. Use `TextAdaptive.get_strategy_class()` to access the enum. +# +# The cell below exercises all of them at once. + +# %% +strategy_class = TextAdaptive.get_strategy_class() + +configured_scenario = TextAdaptive( + epsilon=0.3, + max_attempts_per_objective=5, + context_extractor=harm_category_context, + seed=42, +) + +await configured_scenario.initialize_async( # type: ignore + objective_target=objective_target, + scenario_strategies=[strategy_class("single_turn")], + dataset_config=DatasetConfiguration( + dataset_names=["airt_hate", "airt_violence"], + max_dataset_size=4, + ), +) +configured_result = await configured_scenario.run_async() # type: ignore +await printer.write_async(configured_result) # type: ignore + +# %% [markdown] +# ## Resuming a run +# +# Adaptive scenarios are resumable — pass `scenario_result_id=...` to the `TextAdaptive` +# constructor and the run picks up where it left off, with prior outcomes replayed into +# the selector. Resume must use the same configuration as the original run. + +# %% +resumed_scenario = TextAdaptive( + epsilon=0.3, + max_attempts_per_objective=5, + context_extractor=harm_category_context, + seed=42, + scenario_result_id=str(configured_result.id), +) + +await resumed_scenario.initialize_async( # type: ignore + objective_target=objective_target, + scenario_strategies=[strategy_class("single_turn")], + dataset_config=DatasetConfiguration( + dataset_names=["airt_hate", "airt_violence"], + max_dataset_size=4, + ), +) +resumed_result = await resumed_scenario.run_async() # type: ignore +await printer.write_async(resumed_result) # type: ignore + +# %% [markdown] +# ## Inspecting which techniques were tried +# +# The dispatcher stamps every objective's `AttackResult.metadata` with: +# +# - `adaptive_context` — the bucket key from the `context_extractor`. +# - `adaptive_attempts` — the ordered list of `{"technique", "outcome"}` dicts +# recording exactly which techniques the selector picked and what happened. +# +# Walk that metadata to see the per-objective trail and aggregate counts. + +# %% +from collections import Counter + +# Per-objective trail +for results in resumed_result.attack_results.values(): + for r in results: + attempts = r.metadata.get("adaptive_attempts", []) + trail = " → ".join(f"{a['technique']}({a['outcome']})" for a in attempts) + print(f"[{r.outcome.value:7s}] {r.objective!r}: {trail}") + +# Aggregate per-technique pick counts and success rate across the run +picks: Counter[str] = Counter() +wins: Counter[str] = Counter() +for results in resumed_result.attack_results.values(): + for r in results: + for step in r.metadata.get("adaptive_attempts", []): + picks[step["technique"]] += 1 + if step["outcome"] == "success": + wins[step["technique"]] += 1 + +print("\nTechnique wins / picks rate") +for technique, n in picks.most_common(): + print(f"{technique:20s} {wins[technique]:>4} / {n:<4} {wins[technique] / n:.0%}") diff --git a/doc/myst.yml b/doc/myst.yml index 491d87556..70bda4aee 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -168,6 +168,7 @@ project: children: - file: code/scenarios/1_common_scenario_parameters.ipynb - file: code/scenarios/2_custom_scenario_parameters.ipynb + - file: code/scenarios/3_adaptive_scenarios.ipynb - file: code/registry/0_registry.md children: - file: code/registry/1_class_registry.ipynb diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index c3e6956e2..37337400b 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -39,17 +39,20 @@ # Import scenario submodules directly and register them as virtual subpackages # This allows: from pyrit.scenario.airt import ContentHarms # without needing separate pyrit/scenario/airt/ directories +from pyrit.scenario.scenarios import adaptive as _adaptive_module from pyrit.scenario.scenarios import airt as _airt_module from pyrit.scenario.scenarios import benchmark as _benchmark_module from pyrit.scenario.scenarios import foundry as _foundry_module from pyrit.scenario.scenarios import garak as _garak_module +sys.modules["pyrit.scenario.adaptive"] = _adaptive_module sys.modules["pyrit.scenario.airt"] = _airt_module sys.modules["pyrit.scenario.benchmark"] = _benchmark_module sys.modules["pyrit.scenario.garak"] = _garak_module sys.modules["pyrit.scenario.foundry"] = _foundry_module # Also expose as attributes for IDE support +adaptive = _adaptive_module airt = _airt_module benchmark = _benchmark_module garak = _garak_module @@ -74,9 +77,10 @@ "ScenarioResult", "StrategyGraph", "StrategyPolicy", - "linear_strategy_policy", + "adaptive", "airt", "benchmark", - "garak", "foundry", + "garak", + "linear_strategy_policy", ] diff --git a/pyrit/scenario/scenarios/adaptive/__init__.py b/pyrit/scenario/scenarios/adaptive/__init__.py new file mode 100644 index 000000000..d0bd978c2 --- /dev/null +++ b/pyrit/scenario/scenarios/adaptive/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Adaptive scenario classes.""" + +from pyrit.scenario.scenarios.adaptive.adaptive_scenario import AdaptiveScenario +from pyrit.scenario.scenarios.adaptive.dispatcher import ( + ADAPTIVE_CONTEXT_LABEL, + AdaptiveDispatchAttack, +) +from pyrit.scenario.scenarios.adaptive.selector import ( + AdaptiveTechniqueSelector, + ContextExtractor, + global_context, + harm_category_context, +) +from pyrit.scenario.scenarios.adaptive.text_adaptive import TextAdaptive + +__all__ = [ + "ADAPTIVE_CONTEXT_LABEL", + "AdaptiveDispatchAttack", + "AdaptiveScenario", + "AdaptiveTechniqueSelector", + "ContextExtractor", + "TextAdaptive", + "global_context", + "harm_category_context", +] diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py new file mode 100644 index 000000000..723849ce9 --- /dev/null +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -0,0 +1,327 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``AdaptiveScenario`` — modality-agnostic base for scenarios that pick attack +techniques per-objective using an ``AdaptiveTechniqueSelector``. + +Owns selector wiring, dispatcher construction, per-objective atomic-attack +emission, and resume rehydration. Concrete subclasses (``TextAdaptive``, +future ``ImageAdaptive`` / ``AudioAdaptive``) only declare strategy class, +default datasets, version, and atomic-attack prefix. + +Baseline policy is ``Forbidden``: ``prompt_sending`` participates as one of +the selector's techniques rather than being prepended. +""" + +from __future__ import annotations + +import logging +import random +import uuid +from typing import TYPE_CHECKING, ClassVar + +from pyrit.executor.attack import AttackScoringConfig +from pyrit.scenario.core.atomic_attack import AtomicAttack +from pyrit.scenario.core.attack_technique import AttackTechnique +from pyrit.scenario.core.scenario import BaselinePolicy, Scenario +from pyrit.scenario.scenarios.adaptive.dispatcher import ( + ADAPTIVE_CONTEXT_LABEL, + AdaptiveDispatchAttack, + TechniqueBundle, +) +from pyrit.scenario.scenarios.adaptive.selector import ( + AdaptiveTechniqueSelector, + ContextExtractor, + global_context, +) + +if TYPE_CHECKING: + from pyrit.models import SeedAttackGroup + from pyrit.prompt_target import PromptTarget + from pyrit.score import TrueFalseScorer + +logger = logging.getLogger(__name__) + + +class AdaptiveScenario(Scenario): + """ + Abstract base for adaptive (epsilon-greedy) scenarios. + + Subclasses must implement the standard ``Scenario`` class-method overrides + and declare ``VERSION`` and ``_atomic_attack_prefix``. Selector wiring, + dispatcher construction, per-objective atomic-attack emission, and resume + rehydration are handled here. + """ + + BASELINE_POLICY: ClassVar[BaselinePolicy] = BaselinePolicy.Forbidden + + #: Subclasses must declare a scenario version for memory bookkeeping. + VERSION: ClassVar[int] + + #: Prefix for per-objective atomic-attack names (e.g. ``"adaptive_text"``). + _atomic_attack_prefix: ClassVar[str] = "adaptive" + + def __init__( + self, + *, + objective_scorer: TrueFalseScorer | None = None, + epsilon: float = 0.2, + pool_threshold: int = 3, + max_attempts_per_objective: int = 3, + seed: int | None = None, + context_extractor: ContextExtractor = global_context, + scenario_result_id: str | None = None, + ) -> None: + """ + Args: + objective_scorer (TrueFalseScorer | None): Scorer used to judge each + response. Defaults to the composite scorer from the base class. + epsilon (float): Exploration probability for the selector. Defaults to 0.2. + pool_threshold (int): Minimum per-(context, technique) attempts before + the local estimate overrides the pooled rate. Set to 1 to disable + pooling. Defaults to 3. + max_attempts_per_objective (int): Max techniques per objective. Defaults to 3. + seed (int | None): RNG seed for deterministic selection. Defaults to ``None``. + context_extractor (ContextExtractor): Maps a ``SeedAttackGroup`` to a + context key. Defaults to ``global_context``. + scenario_result_id (str | None): ID of an existing ``ScenarioResult`` to resume. + """ + if not objective_scorer: + objective_scorer = self._get_default_objective_scorer() + self._objective_scorer: TrueFalseScorer = objective_scorer + + self._epsilon = epsilon + self._pool_threshold = pool_threshold + self._max_attempts_per_objective = max_attempts_per_objective + self._seed = seed + self._context_extractor = context_extractor + + super().__init__( + version=self.VERSION, + strategy_class=self.get_strategy_class(), + objective_scorer=objective_scorer, + scenario_result_id=scenario_result_id, + ) + + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + """ + Build one ``AtomicAttack`` per objective. + + Each objective gets a freshly constructed ``AdaptiveDispatchAttack`` + bound to its seed group, but all dispatchers share the same selector + so learning accumulates across objectives. Per-objective, techniques + whose ``seed_technique`` is incompatible with the seed group are + filtered out; objectives left with no compatible techniques are skipped. + + Returns: + list[AtomicAttack]: One ``AtomicAttack`` per objective with at + least one compatible technique. Empty if every seed group + is incompatible with every selected technique. + + Raises: + ValueError: If ``self._objective_target`` is not set, or if + ``_build_techniques_dict`` finds no usable techniques. + """ + if self._objective_target is None: + raise ValueError("objective_target must be set before creating attacks") + + techniques = self._build_techniques_dict(objective_target=self._objective_target) + + selector = AdaptiveTechniqueSelector( + epsilon=self._epsilon, + pool_threshold=self._pool_threshold, + rng=random.Random(self._seed), + ) + # On resume, replay prior attempt outcomes from persisted metadata. + self._rehydrate_selector_from_memory(selector=selector, known_techniques=set(techniques)) + + seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() + atomic_attacks: list[AtomicAttack] = [] + for dataset_name, seed_groups in seed_groups_by_dataset.items(): + for seed_group in seed_groups: + atomic = self._build_atomic_for_seed_group( + dataset_name=dataset_name, + seed_group=seed_group, + techniques=techniques, + selector=selector, + ) + if atomic is not None: + atomic_attacks.append(atomic) + + return atomic_attacks + + def _build_techniques_dict( + self, + *, + objective_target: PromptTarget, + ) -> dict[str, TechniqueBundle]: + """ + Resolve selected strategies into a ``{name: TechniqueBundle}`` map. + + Each bundle carries the inner attack strategy along with the factory's + ``seed_technique`` and ``adversarial_chat`` so the dispatcher can + reproduce the static ``AtomicAttack`` execution path per attempt. + + Returns: + dict[str, TechniqueBundle]: Mapping from technique name to its + bundle, in the order selected strategies were resolved. + + Raises: + ValueError: If no techniques remain after filtering. Includes the + requested techniques and skip reasons. + """ + selected_techniques = sorted({s.value for s in self._scenario_strategies}) + factories = self._get_attack_technique_factories() + scoring_config = AttackScoringConfig(objective_scorer=self._objective_scorer) + + techniques: dict[str, TechniqueBundle] = {} + skipped_no_factory: list[str] = [] + for technique_name in selected_techniques: + factory = factories.get(technique_name) + if factory is None: + skipped_no_factory.append(technique_name) + logger.warning(f"No factory for technique '{technique_name}', skipping.") + continue + technique = factory.create( + objective_target=objective_target, + attack_scoring_config=scoring_config, + ) + techniques[technique_name] = TechniqueBundle( + attack=technique.attack, + seed_technique=technique.seed_technique, + adversarial_chat=factory.adversarial_chat, + ) + + if not techniques: + suffix = f" (skipped, no factory registered: {sorted(skipped_no_factory)})" if skipped_no_factory else "" + raise ValueError( + f"{type(self).__name__}: no usable techniques after resolving strategies. " + f"Check the --strategies selection.{suffix}" + ) + + return techniques + + def _build_atomic_for_seed_group( + self, + *, + dataset_name: str, + seed_group: SeedAttackGroup, + techniques: dict[str, TechniqueBundle], + selector: AdaptiveTechniqueSelector, + ) -> AtomicAttack | None: + """ + Build a single ``AtomicAttack`` for one ``SeedAttackGroup``. + + Filters the technique pool down to those whose ``seed_technique`` (if + any) is compatible with this seed group, then constructs a dedicated + ``AdaptiveDispatchAttack`` bound to this seed group. + + Returns: + AtomicAttack | None: The constructed atomic attack, or ``None`` when + no techniques are compatible (caller skips the objective). + + Raises: + ValueError: If ``self._objective_target`` is not set (defensive + guard; ``_get_atomic_attacks_async`` enforces this earlier). + """ + if self._objective_target is None: # pragma: no cover - defensive + raise ValueError("objective_target must be set before creating attacks") + + compatible: dict[str, TechniqueBundle] = { + name: bundle + for name, bundle in techniques.items() + if bundle.seed_technique is None or seed_group.is_compatible_with_technique(technique=bundle.seed_technique) + } + + if not compatible: + logger.warning( + "AdaptiveScenario: no compatible techniques for seed group in dataset '%s' (objective=%r); skipping.", + dataset_name, + seed_group.objective.value, + ) + return None + + adaptive_context = self._context_extractor(seed_group) + # Prefer the objective's id when available so resume keys stay stable + # across re-fetches of the same seed groups. + objective_id = seed_group.objective.id if seed_group.objective.id else uuid.uuid4() + atomic_attack_name = f"{self._atomic_attack_prefix}_{dataset_name}_{objective_id}" + + dispatcher = AdaptiveDispatchAttack( + objective_target=self._objective_target, + techniques=compatible, + selector=selector, + seed_group=seed_group, + objective_scorer=self._objective_scorer, + max_attempts_per_objective=self._max_attempts_per_objective, + ) + + memory_labels = { + **self._memory_labels, + ADAPTIVE_CONTEXT_LABEL: adaptive_context, + } + return AtomicAttack( + atomic_attack_name=atomic_attack_name, + attack_technique=AttackTechnique(attack=dispatcher), + seed_groups=[seed_group], + objective_scorer=self._objective_scorer, + memory_labels=memory_labels, + display_group=dataset_name, + ) + + def _rehydrate_selector_from_memory( + self, + *, + selector: AdaptiveTechniqueSelector, + known_techniques: set[str], + ) -> None: + """ + Replay persisted dispatch trails into ``selector`` so resume + preserves learned state. + + Iterates every persisted ``AttackResult`` on the resumed + ``ScenarioResult`` and calls ``record_outcome`` once per attempt in + each ``metadata["adaptive_attempts"]`` trail. + + Args: + selector (AdaptiveTechniqueSelector): A freshly built selector to populate. + known_techniques (set[str]): Techniques available in the current run. + Trails referencing unknown techniques (e.g. after a strategies + change) are skipped so replay can't poison the table. + """ + if not self._scenario_result_id: + return + + # Narrow to errors a memory backend would plausibly raise (DB/IO + # failures, integrity issues). Programmer-level errors propagate. + try: + scenario_results = self._memory.get_scenario_results(scenario_result_ids=[self._scenario_result_id]) + except (RuntimeError, OSError, ValueError) as exc: + logger.warning(f"AdaptiveScenario: failed to load prior scenario result for rehydration: {exc}") + return + + if not scenario_results: + return + + replayed = 0 + for results_list in scenario_results[0].attack_results.values(): + for result in results_list: + trail = result.metadata.get("adaptive_attempts") if result.metadata else None + context = result.metadata.get("adaptive_context") if result.metadata else None + if not trail or not context: + continue + for step in trail: + technique = step.get("technique") + outcome = step.get("outcome") + if not technique or technique not in known_techniques: + continue + selector.record_outcome( + context=context, + technique=technique, + success=outcome == "success", + ) + replayed += 1 + + if replayed: + logger.info(f"AdaptiveScenario: rehydrated selector with {replayed} prior attempt(s).") diff --git a/pyrit/scenario/scenarios/adaptive/dispatcher.py b/pyrit/scenario/scenarios/adaptive/dispatcher.py new file mode 100644 index 000000000..46808bfde --- /dev/null +++ b/pyrit/scenario/scenarios/adaptive/dispatcher.py @@ -0,0 +1,278 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``AdaptiveDispatchAttack`` — picks an inner technique per attempt via an +``AdaptiveTechniqueSelector``, runs it, records the outcome, and loops up to +``max_attempts_per_objective`` times. Reads the per-objective context key from +``context.memory_labels[ADAPTIVE_CONTEXT_LABEL]`` (falls back to the global context). + +The dispatcher is bound to a single ``SeedAttackGroup`` at construction time so +it can merge each chosen technique's ``seed_technique`` (when present) into the +seed group before delegating execution to ``AttackExecutor``. +""" + +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, replace +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from pyrit.executor.attack.core.attack_executor import AttackExecutor +from pyrit.executor.attack.core.attack_parameters import AttackParameters +from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy +from pyrit.models import AttackOutcome, AttackResult +from pyrit.scenario.scenarios.adaptive.selector import ( + GLOBAL_CONTEXT, + AdaptiveTechniqueSelector, +) + +if TYPE_CHECKING: + from pyrit.models import SeedAttackGroup, SeedAttackTechniqueGroup + from pyrit.prompt_target import PromptTarget + from pyrit.score import TrueFalseScorer + +logger = logging.getLogger(__name__) + + +# Memory-label keys stamped onto persisted prompt rows so adaptive attempts +# can be filtered/grouped after a run. The scenario stamps the context once +# per objective; the dispatcher stamps technique + attempt index on each try. +ADAPTIVE_CONTEXT_LABEL: str = "_adaptive_context" +"""Per-objective context key (e.g. ``"_global"`` or a harm category).""" +ADAPTIVE_TECHNIQUE_LABEL: str = "_adaptive_technique" +"""Technique chosen by the dispatcher for a given attempt.""" +ADAPTIVE_ATTEMPT_LABEL: str = "_adaptive_attempt" +"""1-based attempt index within the per-objective loop.""" + + +@dataclass(frozen=True) +class TechniqueBundle: + """ + Per-technique bundle consumed by the dispatcher. + + Carries the inner attack strategy alongside the factory-supplied + ``seed_technique`` (if any) and ``adversarial_chat`` (required when the + seed_technique contains a simulated-conversation config). + """ + + attack: AttackStrategy[Any, AttackResult] + seed_technique: SeedAttackTechniqueGroup | None = None + adversarial_chat: PromptTarget | None = None + + +@dataclass +class AdaptiveDispatchContext(AttackContext[AttackParameters]): + """Execution context for ``AdaptiveDispatchAttack`` (no extra state).""" + + +class AdaptiveDispatchAttack(AttackStrategy[AdaptiveDispatchContext, AttackResult]): + """ + Attack that delegates each attempt to one of several inner techniques, + choosing per attempt via an ``AdaptiveTechniqueSelector``. + + For each objective, loops up to ``max_attempts_per_objective`` times: + ask the selector, execute the chosen technique, record the outcome, and + stop early on success. The selector is shared by reference with the + scenario so learning accumulates across objectives. + + The dispatcher is bound to a single ``SeedAttackGroup`` at construction + time. When a chosen technique declares a ``seed_technique``, that group + is merged into the seed group before execution (mirroring the static + ``AtomicAttack`` path). + + On success, the dispatcher returns a fresh ``AttackResult`` copy of the + winning inner result (new ``attack_result_id`` and ``timestamp``) with + the dispatch trail stamped onto ``metadata``. The inner result has + already been persisted by its own post-execute hook, so two rows are + written per successful objective sharing the same ``conversation_id``: + the inner row carries the raw outcome, the outer row carries the + adaptive trail. + """ + + def __init__( + self, + *, + objective_target: PromptTarget, + techniques: dict[str, TechniqueBundle], + selector: AdaptiveTechniqueSelector, + seed_group: SeedAttackGroup, + objective_scorer: TrueFalseScorer | None = None, + max_attempts_per_objective: int = 3, + ) -> None: + """ + Args: + objective_target (PromptTarget): The target inner attacks run against. + Stored for identifier/logging parity; not called directly. + techniques (dict[str, TechniqueBundle]): Mapping from technique name to + its bundle (attack, seed_technique, adversarial_chat). Must be non-empty. + selector (AdaptiveTechniqueSelector): Shared selector state. + seed_group (SeedAttackGroup): The seed group bound to this dispatcher. + Each attempt's chosen technique is applied against this group + (merging the technique's ``seed_technique`` when present). + objective_scorer (TrueFalseScorer | None): Scorer passed through to + techniques that generate simulated conversations. + max_attempts_per_objective (int): Max attempts per objective; >= 1. + Defaults to 3. + + Raises: + ValueError: If ``techniques`` is empty or ``max_attempts_per_objective`` < 1. + """ + if not techniques: + raise ValueError("techniques must contain at least one attack technique") + if max_attempts_per_objective < 1: + raise ValueError(f"max_attempts_per_objective must be >= 1, got {max_attempts_per_objective}") + + super().__init__( + objective_target=objective_target, + context_type=AdaptiveDispatchContext, + params_type=AttackParameters, + logger=logger, + ) + self._techniques = techniques + self._selector = selector + self._seed_group = seed_group + self._objective_scorer = objective_scorer + self._max_attempts = max_attempts_per_objective + # Attempts are inherently sequential (each one reads the selector + # state updated by the previous), so a single shared executor with + # ``max_concurrency=1`` is reused across attempts. + self._executor = AttackExecutor(max_concurrency=1) + + def _validate_context(self, *, context: AdaptiveDispatchContext) -> None: + """ + Ensure the context carries a non-empty objective string. + + Raises: + ValueError: If ``context.objective`` is empty or whitespace-only. + """ + if not context.objective or context.objective.isspace(): + raise ValueError("Attack objective must be provided and non-empty") + + async def _setup_async(self, *, context: AdaptiveDispatchContext) -> None: + """No-op: per-attempt setup is owned by the inner technique's executor.""" + + async def _teardown_async(self, *, context: AdaptiveDispatchContext) -> None: + """No-op: per-attempt teardown is owned by the inner technique's executor.""" + + async def _run_inner_attack_async( + self, + *, + bundle: TechniqueBundle, + attempt_labels: dict[str, str], + ) -> AttackResult: + """ + Execute the chosen technique against this dispatcher's seed group. + + Merges ``bundle.seed_technique`` into the bound ``seed_group`` (when + present) and delegates execution to ``AttackExecutor``. Isolated as a + method so tests can patch the inner-attack call surface. + + Args: + bundle (TechniqueBundle): The chosen technique's attack + seeds + chat. + attempt_labels (dict[str, str]): Memory labels stamped onto this attempt. + + Returns: + AttackResult: The single result produced for this attempt. + + Raises: + RuntimeError: If the executor returned no completed results and no + propagated exception (should be unreachable). + """ + if bundle.seed_technique is not None: + execution_group = self._seed_group.with_technique(technique=bundle.seed_technique) + else: + execution_group = self._seed_group + + executor_result = await self._executor.execute_attack_from_seed_groups_async( + attack=bundle.attack, + seed_groups=[execution_group], + adversarial_chat=bundle.adversarial_chat, + objective_scorer=self._objective_scorer, + memory_labels=attempt_labels, + ) + + if executor_result.completed_results: + return executor_result.completed_results[0] + if executor_result.incomplete_objectives: + raise executor_result.incomplete_objectives[0][1] + raise RuntimeError( # pragma: no cover - defensive + "AttackExecutor returned neither completed nor incomplete results." + ) + + async def _perform_async(self, *, context: AdaptiveDispatchContext) -> AttackResult: + """ + Run the per-objective adaptive loop. + + Resolves the per-objective context key from ``context.memory_labels`` + (falling back to :data:`GLOBAL_CONTEXT`), then loops up to + ``max_attempts_per_objective`` times: select a technique, execute it, + record the outcome, and stop early on success. + + Args: + context (AdaptiveDispatchContext): Execution context. ``memory_labels`` + may carry :data:`ADAPTIVE_CONTEXT_LABEL` to scope the selector. + + Returns: + AttackResult: A fresh dispatcher-owned copy of the final inner + result with the dispatch trail stamped onto ``metadata`` + (see class docstring for the two-row persistence note). + + Raises: + RuntimeError: If the loop somehow ran zero attempts (unreachable + because ``max_attempts_per_objective`` is validated >= 1). + """ + adaptive_context = context.memory_labels.get(ADAPTIVE_CONTEXT_LABEL, GLOBAL_CONTEXT) + technique_names = list(self._techniques.keys()) + + last_result: AttackResult | None = None + trail: list[dict[str, str]] = [] + + for attempt_idx in range(self._max_attempts): + chosen = self._selector.select(context=adaptive_context, techniques=technique_names) + bundle = self._techniques[chosen] + attempt_labels = { + **context.memory_labels, + ADAPTIVE_TECHNIQUE_LABEL: chosen, + ADAPTIVE_ATTEMPT_LABEL: str(attempt_idx + 1), + } + + logger.debug( + "AdaptiveDispatchAttack: attempt %d/%d context=%r technique=%r", + attempt_idx + 1, + self._max_attempts, + adaptive_context, + chosen, + ) + + result = await self._run_inner_attack_async(bundle=bundle, attempt_labels=attempt_labels) + success = result.outcome == AttackOutcome.SUCCESS + self._selector.record_outcome(context=adaptive_context, technique=chosen, success=success) + + trail.append({"technique": chosen, "outcome": result.outcome.value}) + last_result = result + + if success: + break + + # ``max_attempts`` is validated >= 1, so the loop always runs at least + # once. Guard explicitly rather than with ``assert`` (stripped under -O). + if last_result is None: # pragma: no cover - defensive + raise RuntimeError("AdaptiveDispatchAttack ran zero attempts; this should be unreachable.") + # Return a fresh dispatcher-owned ``AttackResult``: the inner attack + # already persisted ``last_result`` via its own post-execute hook, so + # returning it directly would cause a PK conflict on the outer hook. + # ``dataclasses.replace`` copies every field; we override identity + # fields and stamp the trail onto metadata. + return replace( + last_result, + attack_result_id=str(uuid.uuid4()), + timestamp=datetime.now(timezone.utc), + metadata={ + **last_result.metadata, + "adaptive_attempts": trail, + "adaptive_context": adaptive_context, + }, + ) diff --git a/pyrit/scenario/scenarios/adaptive/selector.py b/pyrit/scenario/scenarios/adaptive/selector.py new file mode 100644 index 000000000..d2d9e63a7 --- /dev/null +++ b/pyrit/scenario/scenarios/adaptive/selector.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Epsilon-greedy selector and context extractors for adaptive scenarios.""" + +from __future__ import annotations + +import random +import threading +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.models.seeds.seed_attack_group import SeedAttackGroup + +"""Maps a ``SeedAttackGroup`` to an adaptive context key.""" +ContextExtractor = Callable[["SeedAttackGroup"], str] +"""Default context: all objectives share one selection table.""" +GLOBAL_CONTEXT: str = "_global" +"""Fallback context for seed groups with no harm category metadata.""" +UNCATEGORIZED_CONTEXT: str = "_uncategorized" + + +def global_context(_seed_attack_group: SeedAttackGroup) -> str: + """ + Return a single shared context for all objectives. + + Returns: + str: Always :data:`GLOBAL_CONTEXT`. + """ + return GLOBAL_CONTEXT + + +def harm_category_context(seed_attack_group: SeedAttackGroup) -> str: + """ + Return a context keyed by the sorted, ``|``-joined harm categories. + + Multi-category seeds form their own bucket; sorting makes the key deterministic. + + Returns: + str: The ``|``-joined sorted harm categories, or :data:`UNCATEGORIZED_CONTEXT` + when the seed group has no categories. + """ + categories = seed_attack_group.harm_categories + if not categories: + return UNCATEGORIZED_CONTEXT + return "|".join(sorted(categories)) + + +class AdaptiveTechniqueSelector: + """ + Epsilon-greedy selector over attack techniques. + + Maintains a ``(context, technique) -> (successes, attempts)`` table. With + probability ``epsilon`` picks uniformly at random; otherwise picks the + technique with the highest Laplace-smoothed estimate ``(s + 1) / (n + 1)`` + (unseen techniques start at 1.0). A ``(context, technique)`` cell with + fewer than ``pool_threshold`` attempts falls back to the technique's + pooled rate across all contexts. + + All public methods are guarded by a ``threading.Lock`` so concurrent + callers cannot corrupt the table. The lock makes individual ops atomic, + not the overall select → execute → record sequence. + """ + + # Tolerance for tiebreaking on float estimates (current estimates are exact + # rationals; this guards against future estimator changes). + _TIE_TOL: float = 1e-12 + + def __init__( + self, + *, + epsilon: float = 0.2, + pool_threshold: int = 3, + rng: random.Random | None = None, + ) -> None: + """ + Args: + epsilon (float): Exploration probability in [0.0, 1.0]. Defaults to 0.2. + pool_threshold (int): Minimum per-(context, technique) attempts before + the local estimate replaces the pooled rate. Must be >= 1; set to 1 + to disable pooling. Defaults to 3. + rng (random.Random | None): RNG for reproducible decisions. Defaults + to a fresh unseeded ``random.Random()``. + + Raises: + ValueError: If ``epsilon`` is outside [0.0, 1.0] or ``pool_threshold`` < 1. + """ + if not 0.0 <= epsilon <= 1.0: + raise ValueError(f"epsilon must be in [0.0, 1.0], got {epsilon}") + if pool_threshold < 1: + raise ValueError(f"pool_threshold must be >= 1, got {pool_threshold}") + + self._epsilon = epsilon + self._pool_threshold = pool_threshold + self._rng = rng if rng is not None else random.Random() + self._counts: dict[tuple[str, str], tuple[int, int]] = {} + # Per-technique pooled counts, kept in sync with ``_counts`` so the + # pooled-backoff branch in ``_estimate`` is O(1). + self._global_counts: dict[str, tuple[int, int]] = {} + # Guards _counts, _global_counts, and _rng against concurrent callers. + self._lock = threading.Lock() + + def select(self, *, context: str, techniques: Sequence[str]) -> str: + """ + Pick the next technique to try for ``context``. + + Args: + context (str): The context key. + techniques (Sequence[str]): Candidate technique names. + + Returns: + str: The chosen technique name. + + Raises: + ValueError: If ``techniques`` is empty. + """ + technique_list = list(techniques) + if not technique_list: + raise ValueError("techniques must contain at least one entry") + + with self._lock: + if self._rng.random() < self._epsilon: + return self._rng.choice(technique_list) + + estimates = {t: self._estimate(context=context, technique=t) for t in technique_list} + best = max(estimates.values()) + winners = [t for t, value in estimates.items() if value >= best - self._TIE_TOL] + return self._rng.choice(winners) + + def record_outcome(self, *, context: str, technique: str, success: bool) -> None: + """ + Record the outcome of an attempt. + + Args: + context (str): The context key the decision was made under. + technique (str): The technique that was tried. + success (bool): Whether the attempt succeeded. + """ + with self._lock: + successes, attempts = self._counts.get((context, technique), (0, 0)) + attempts += 1 + if success: + successes += 1 + self._counts[(context, technique)] = (successes, attempts) + + global_successes, global_attempts = self._global_counts.get(technique, (0, 0)) + global_attempts += 1 + if success: + global_successes += 1 + self._global_counts[technique] = (global_successes, global_attempts) + + def success_rate(self, *, context: str, technique: str) -> float: + """Return the Laplace-smoothed estimate ``(s + 1) / (n + 1)`` used for exploitation.""" + with self._lock: + return self._estimate(context=context, technique=technique) + + def counts(self, *, context: str, technique: str) -> tuple[int, int]: + """Return raw ``(successes, attempts)`` for a ``(context, technique)`` cell.""" + with self._lock: + return self._counts.get((context, technique), (0, 0)) + + def snapshot(self) -> dict[tuple[str, str], tuple[int, int]]: + """Return a shallow copy of the full counts table (for logging/debug).""" + with self._lock: + return dict(self._counts) + + def _estimate(self, *, context: str, technique: str) -> float: + """ + Estimate for ``(context, technique)``; falls back to pooled rate below + ``pool_threshold`` local attempts. + + Callers must already hold ``self._lock``. + + Returns: + float: Laplace-smoothed success-rate estimate in ``(0, 1)``. + """ + local_s, local_n = self._counts.get((context, technique), (0, 0)) + if local_n >= self._pool_threshold: + return (local_s + 1) / (local_n + 1) + global_s, global_n = self._global_counts.get(technique, (0, 0)) + return (global_s + 1) / (global_n + 1) diff --git a/pyrit/scenario/scenarios/adaptive/text_adaptive.py b/pyrit/scenario/scenarios/adaptive/text_adaptive.py new file mode 100644 index 000000000..4bbbe7ff4 --- /dev/null +++ b/pyrit/scenario/scenarios/adaptive/text_adaptive.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``TextAdaptive`` — text adaptive scenario. + +Picks attack techniques per-objective using an epsilon-greedy selector +informed by observed success rates. Runs up to ``max_attempts_per_objective`` +techniques per objective and stops early on success. The available techniques +come from the selected scenario strategies (``--strategies single_turn`` +restricts to single-turn techniques, etc.). +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, ClassVar + +from pyrit.common import apply_defaults +from pyrit.registry.tag_query import TagQuery +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.scenarios.adaptive.adaptive_scenario import AdaptiveScenario +from pyrit.scenario.scenarios.adaptive.selector import ( + ContextExtractor, + global_context, +) + +if TYPE_CHECKING: + from pyrit.scenario.core.scenario_strategy import ScenarioStrategy + from pyrit.score import TrueFalseScorer + +logger = logging.getLogger(__name__) + + +def _build_text_adaptive_strategy() -> type[ScenarioStrategy]: + """ + Build the strategy enum from the core scenario-techniques catalog. + + Returns: + type[ScenarioStrategy]: The dynamically-built strategy enum class. + """ + from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueRegistry, + ) + from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES + + return AttackTechniqueRegistry.build_strategy_class_from_specs( # type: ignore[return-value, ty:invalid-return-type] + class_name="TextAdaptiveStrategy", + specs=SCENARIO_TECHNIQUES, + aggregate_tags={ + "default": TagQuery.any_of("default"), + "single_turn": TagQuery.any_of("single_turn"), + "multi_turn": TagQuery.any_of("multi_turn"), + }, + ) + + +class TextAdaptive(AdaptiveScenario): + """ + Adaptive text-attack scenario. + + Selects techniques per-objective via an epsilon-greedy selector over the + set of selected strategies. ``prompt_sending`` participates as one of the + selector's techniques rather than being prepended as a baseline. + """ + + VERSION: int = 1 + _atomic_attack_prefix: ClassVar[str] = "adaptive" + _cached_strategy_class: ClassVar[type[ScenarioStrategy] | None] = None + + @classmethod + def get_strategy_class(cls) -> type[ScenarioStrategy]: + """Return the strategy enum for this scenario, building it once on first access.""" + if cls._cached_strategy_class is None: + cls._cached_strategy_class = _build_text_adaptive_strategy() + return cls._cached_strategy_class + + @classmethod + def get_default_strategy(cls) -> ScenarioStrategy: + """Return the default strategy aggregate (resolves to every ``default``-tagged technique).""" + strategy_class = cls.get_strategy_class() + return strategy_class("default") + + @classmethod + def required_datasets(cls) -> list[str]: + """Return the dataset names this scenario expects when no override is provided.""" + return [ + "airt_hate", + "airt_fairness", + "airt_violence", + "airt_sexual", + "airt_harassment", + "airt_misinformation", + "airt_leakage", + ] + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + """Return the default :class:`DatasetConfiguration` (required datasets, capped at 4 per dataset).""" + return DatasetConfiguration(dataset_names=cls.required_datasets(), max_dataset_size=4) + + @apply_defaults + def __init__( + self, + *, + objective_scorer: TrueFalseScorer | None = None, + epsilon: float = 0.2, + pool_threshold: int = 3, + max_attempts_per_objective: int = 3, + seed: int | None = None, + context_extractor: ContextExtractor = global_context, + scenario_result_id: str | None = None, + ) -> None: + """ + Args: + objective_scorer (TrueFalseScorer | None): Scorer used to judge each + response. Defaults to the composite scorer from the base class. + epsilon (float): Exploration probability for the selector. Defaults to 0.2. + pool_threshold (int): Minimum per-(context, technique) attempts before + the local estimate overrides the pooled rate. Set to 1 to disable + pooling. Defaults to 3. + max_attempts_per_objective (int): Max techniques per objective. Defaults to 3. + seed (int | None): RNG seed for deterministic selection. Defaults to ``None``. + context_extractor (ContextExtractor): Maps a ``SeedAttackGroup`` to a + context key. Defaults to ``global_context``. Use + ``harm_category_context`` to partition by harm category. + scenario_result_id (str | None): ID of an existing ``ScenarioResult`` to resume. + """ + super().__init__( + objective_scorer=objective_scorer, + epsilon=epsilon, + pool_threshold=pool_threshold, + max_attempts_per_objective=max_attempts_per_objective, + seed=seed, + context_extractor=context_extractor, + scenario_result_id=scenario_result_id, + ) diff --git a/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py b/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py new file mode 100644 index 000000000..4be4ffbb6 --- /dev/null +++ b/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py @@ -0,0 +1,303 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import random +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.executor.attack.core.attack_parameters import AttackParameters +from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup, SeedObjective +from pyrit.scenario.scenarios.adaptive.dispatcher import ( + ADAPTIVE_ATTEMPT_LABEL, + ADAPTIVE_CONTEXT_LABEL, + ADAPTIVE_TECHNIQUE_LABEL, + AdaptiveDispatchAttack, + AdaptiveDispatchContext, + TechniqueBundle, +) +from pyrit.scenario.scenarios.adaptive.selector import ( + GLOBAL_CONTEXT, + AdaptiveTechniqueSelector, +) + + +def _make_bundle(*, name: str, outcomes: list[AttackOutcome], seed_technique=None) -> TechniqueBundle: + """Build a TechniqueBundle whose attack stub yields the given outcomes in order. + + The dispatcher routes execution through ``_run_inner_attack_async``; tests + patch that method directly so we only need a placeholder attack here. + """ + attack = MagicMock(name=f"attack-{name}") + attack._outcomes = outcomes + attack._name = name + return TechniqueBundle(attack=attack, seed_technique=seed_technique) + + +def _make_context(*, objective: str = "obj", labels: dict[str, str] | None = None) -> AdaptiveDispatchContext: + return AdaptiveDispatchContext(params=AttackParameters(objective=objective, memory_labels=labels or {})) + + +def _patch_inner( + *, + dispatcher: AdaptiveDispatchAttack, + bundles: dict[str, TechniqueBundle], +) -> AsyncMock: + """Replace ``_run_inner_attack_async`` with a stub backed by per-bundle outcomes. + + Returns the AsyncMock so tests can introspect call history (kwargs include + ``bundle`` and ``attempt_labels``). + """ + # Each call consumes one outcome from the chosen bundle's deque. + name_for_attack = {id(b.attack): name for name, b in bundles.items()} + counters: dict[str, int] = dict.fromkeys(bundles, 0) + + async def _stub(*, bundle: TechniqueBundle, attempt_labels: dict[str, str]) -> AttackResult: + name = name_for_attack[id(bundle.attack)] + idx = counters[name] + counters[name] = idx + 1 + outcome = bundle.attack._outcomes[idx] + return AttackResult( + conversation_id=f"conv-{name}-{idx}", + objective="obj", + outcome=outcome, + ) + + inner_mock = AsyncMock(side_effect=_stub) + dispatcher._run_inner_attack_async = inner_mock # type: ignore[method-assign] + return inner_mock + + +@pytest.fixture +def selector() -> AdaptiveTechniqueSelector: + # epsilon=0 makes selection deterministic given the table. + return AdaptiveTechniqueSelector(epsilon=0.0, pool_threshold=1, rng=random.Random(0)) + + +@pytest.fixture +def target() -> MagicMock: + return MagicMock(name="objective_target") + + +@pytest.fixture +def seed_group() -> SeedAttackGroup: + return SeedAttackGroup(seeds=[SeedObjective(value="obj")]) + + +class TestInit: + @pytest.mark.usefixtures("patch_central_database") + def test_init_rejects_empty_techniques(self, target, selector, seed_group): + with pytest.raises(ValueError, match="techniques"): + AdaptiveDispatchAttack( + objective_target=target, + techniques={}, + selector=selector, + seed_group=seed_group, + ) + + @pytest.mark.parametrize("bad_max", [0, -1]) + @pytest.mark.usefixtures("patch_central_database") + def test_init_rejects_invalid_max_attempts(self, target, selector, seed_group, bad_max): + with pytest.raises(ValueError, match="max_attempts_per_objective"): + AdaptiveDispatchAttack( + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=bad_max, + ) + + +@pytest.mark.usefixtures("patch_central_database") +class TestPerform: + async def test_stops_on_first_success(self, target, selector, seed_group): + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS]), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]), + } + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=5, + ) + inner = _patch_inner(dispatcher=dispatcher, bundles=bundles) + + result = await dispatcher._perform_async(context=_make_context()) + + assert result.outcome == AttackOutcome.SUCCESS + assert inner.call_count == 1 + + async def test_retries_until_max_attempts_on_failure(self, target, selector, seed_group): + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.FAILURE] * 3), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.FAILURE] * 3), + } + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=3, + ) + inner = _patch_inner(dispatcher=dispatcher, bundles=bundles) + + result = await dispatcher._perform_async(context=_make_context()) + + assert result.outcome == AttackOutcome.FAILURE + assert inner.call_count == 3 + + async def test_updates_selector_on_each_attempt(self, target, selector, seed_group): + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.FAILURE, AttackOutcome.SUCCESS]), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]), + } + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=3, + ) + inner = _patch_inner(dispatcher=dispatcher, bundles=bundles) + + await dispatcher._perform_async(context=_make_context()) + + total_attempts = sum(selector.counts(context=GLOBAL_CONTEXT, technique=t)[1] for t in ("a", "b")) + assert total_attempts == inner.call_count + + async def test_passes_attempt_labels_to_inner(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + ) + inner = _patch_inner(dispatcher=dispatcher, bundles=bundles) + + await dispatcher._perform_async(context=_make_context(labels={"foo": "bar"})) + + labels = inner.call_args.kwargs["attempt_labels"] + assert labels["foo"] == "bar" # caller labels preserved + assert labels[ADAPTIVE_TECHNIQUE_LABEL] == "a" + assert labels[ADAPTIVE_ATTEMPT_LABEL] == "1" + + async def test_uses_adaptive_context_from_label(self, target, selector, seed_group): + # Two techniques; one has been heavily rewarded under context "violence" only. + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS]), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]), + } + for _ in range(5): + selector.record_outcome(context="violence", technique="b", success=True) + for _ in range(5): + selector.record_outcome(context="violence", technique="a", success=False) + + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + ) + inner = _patch_inner(dispatcher=dispatcher, bundles=bundles) + ctx = _make_context(labels={ADAPTIVE_CONTEXT_LABEL: "violence"}) + await dispatcher._perform_async(context=ctx) + + # Exploit should have picked "b" first. + chosen_bundle = inner.call_args.kwargs["bundle"] + assert chosen_bundle is bundles["b"] + + async def test_falls_back_to_global_context_when_label_missing(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + ) + _patch_inner(dispatcher=dispatcher, bundles=bundles) + await dispatcher._perform_async(context=_make_context(labels={})) + + # The global context bucket received the update. + assert selector.counts(context=GLOBAL_CONTEXT, technique="a") == (1, 1) + + async def test_metadata_records_adaptive_trail(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.FAILURE, AttackOutcome.SUCCESS])} + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=3, + ) + _patch_inner(dispatcher=dispatcher, bundles=bundles) + result = await dispatcher._perform_async(context=_make_context()) + + trail = result.metadata["adaptive_attempts"] + assert trail == [ + {"technique": "a", "outcome": "failure"}, + {"technique": "a", "outcome": "success"}, + ] + assert result.metadata["adaptive_context"] == GLOBAL_CONTEXT + + async def test_returns_fresh_result_distinct_from_inner(self, target, selector, seed_group): + # The dispatcher must NOT return the inner attack's ``AttackResult`` + # instance — doing so would cause a duplicate-PK insert when both the + # inner and the dispatcher's ``execute_async`` post-execute hooks try + # to persist the same row. Verify the returned result has a fresh + # ``attack_result_id`` while preserving the inner's identifying fields + # and stamping the dispatch trail. + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + ) + inner_ids: list[str] = [] + + async def _spy(*, bundle, attempt_labels): + inner_result = AttackResult( + conversation_id="conv-a-0", + objective="obj", + outcome=AttackOutcome.SUCCESS, + ) + inner_ids.append(inner_result.attack_result_id) + return inner_result + + dispatcher._run_inner_attack_async = AsyncMock(side_effect=_spy) # type: ignore[method-assign] + + result = await dispatcher._perform_async(context=_make_context()) + + assert len(inner_ids) == 1 + assert result.attack_result_id != inner_ids[0] + assert result.conversation_id # carried over from inner + assert result.outcome == AttackOutcome.SUCCESS + assert result.metadata["adaptive_attempts"] == [{"technique": "a", "outcome": "success"}] + assert result.metadata["adaptive_context"] == GLOBAL_CONTEXT + + +@pytest.mark.usefixtures("patch_central_database") +class TestValidate: + @pytest.mark.parametrize("bad_objective", ["", " ", "\n\t"]) + def test_validate_rejects_empty_objective(self, target, selector, seed_group, bad_objective): + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + with pytest.raises(ValueError, match="objective"): + dispatcher._validate_context(context=_make_context(objective=bad_objective)) + + def test_validate_accepts_normal_objective(self, target, selector, seed_group): + dispatcher = AdaptiveDispatchAttack( + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + # Does not raise. + dispatcher._validate_context(context=_make_context(objective="ok")) diff --git a/tests/unit/scenario/scenarios/adaptive/test_selector.py b/tests/unit/scenario/scenarios/adaptive/test_selector.py new file mode 100644 index 000000000..2daba3b70 --- /dev/null +++ b/tests/unit/scenario/scenarios/adaptive/test_selector.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import random +from unittest.mock import MagicMock + +import pytest + +from pyrit.scenario.scenarios.adaptive.selector import ( + GLOBAL_CONTEXT, + UNCATEGORIZED_CONTEXT, + AdaptiveTechniqueSelector, + global_context, + harm_category_context, +) + +TECHNIQUES = ["a", "b", "c", "d"] + + +def _seeded_selector(*, epsilon: float = 0.0, pool_threshold: int = 3, seed: int = 0) -> AdaptiveTechniqueSelector: + return AdaptiveTechniqueSelector( + epsilon=epsilon, + pool_threshold=pool_threshold, + rng=random.Random(seed), + ) + + +class TestAdaptiveTechniqueSelectorInit: + def test_init_defaults(self): + selector = AdaptiveTechniqueSelector() + assert selector.snapshot() == {} + + @pytest.mark.parametrize("bad_epsilon", [-0.1, 1.1, 2.0, -1.0]) + def test_init_rejects_out_of_range_epsilon(self, bad_epsilon): + with pytest.raises(ValueError, match="epsilon"): + AdaptiveTechniqueSelector(epsilon=bad_epsilon) + + def test_init_rejects_pool_threshold_below_one(self): + with pytest.raises(ValueError, match="pool_threshold"): + AdaptiveTechniqueSelector(pool_threshold=0) + with pytest.raises(ValueError, match="pool_threshold"): + AdaptiveTechniqueSelector(pool_threshold=-1) + + +class TestAdaptiveTechniqueSelectorSelect: + def test_select_empty_techniques_raises(self): + selector = _seeded_selector() + with pytest.raises(ValueError, match="techniques"): + selector.select(context=GLOBAL_CONTEXT, techniques=[]) + + def test_select_all_unseen_ties_resolved_randomly(self): + # With epsilon=0 and an empty table, every technique has estimate 1/1=1.0, + # so the result is the seeded random tiebreak. Different seeds should + # be able to produce different winners. + winners = {_seeded_selector(seed=s).select(context=GLOBAL_CONTEXT, techniques=TECHNIQUES) for s in range(50)} + assert len(winners) > 1 + assert winners.issubset(set(TECHNIQUES)) + + def test_select_exploits_clear_winner(self): + selector = _seeded_selector(pool_threshold=1) + # Give "b" a track record of pure success, others pure failure. + for _ in range(5): + selector.record_outcome(context=GLOBAL_CONTEXT, technique="b", success=True) + for technique in ("a", "c", "d"): + for _ in range(5): + selector.record_outcome(context=GLOBAL_CONTEXT, technique=technique, success=False) + + # With epsilon=0, every selection must exploit the winner. + for _ in range(20): + assert selector.select(context=GLOBAL_CONTEXT, techniques=TECHNIQUES) == "b" + + def test_select_epsilon_one_is_pure_random(self): + selector = _seeded_selector(epsilon=1.0) + # Bias the table heavily toward "a"; with epsilon=1 it must still be ignored. + for _ in range(20): + selector.record_outcome(context=GLOBAL_CONTEXT, technique="a", success=True) + + picks = [selector.select(context=GLOBAL_CONTEXT, techniques=TECHNIQUES) for _ in range(200)] + assert set(picks) == set(TECHNIQUES) + + def test_select_epsilon_zero_never_explores(self): + selector = _seeded_selector(epsilon=0.0, pool_threshold=1) + for _ in range(3): + selector.record_outcome(context=GLOBAL_CONTEXT, technique="a", success=True) + # Make the other techniques tried-and-failed so they fall below "a"'s estimate; + # unseen techniques would otherwise tie at the optimistic 1.0. + for technique in ("b", "c", "d"): + selector.record_outcome(context=GLOBAL_CONTEXT, technique=technique, success=False) + for _ in range(50): + assert selector.select(context=GLOBAL_CONTEXT, techniques=TECHNIQUES) == "a" + + def test_select_cold_start_round_robins(self): + # Optimistic init + epsilon=0: untried techniques tie at 1.0 and beat tried-and-failed + # techniques (1/2 = 0.5). So the first failures push each technique to "tried" exactly once + # before any technique gets tried twice. + selector = _seeded_selector(pool_threshold=1) + tried: list[str] = [] + for _ in range(len(TECHNIQUES)): + technique = selector.select(context=GLOBAL_CONTEXT, techniques=TECHNIQUES) + tried.append(technique) + selector.record_outcome(context=GLOBAL_CONTEXT, technique=technique, success=False) + assert sorted(tried) == sorted(TECHNIQUES) + + +class TestAdaptiveTechniqueSelectorUpdate: + def test_record_outcome_accumulates_counts(self): + selector = _seeded_selector() + selector.record_outcome(context="ctx", technique="a", success=True) + selector.record_outcome(context="ctx", technique="a", success=False) + selector.record_outcome(context="ctx", technique="a", success=True) + assert selector.counts(context="ctx", technique="a") == (2, 3) + + def test_record_outcome_separate_contexts_are_independent(self): + selector = _seeded_selector() + selector.record_outcome(context="x", technique="a", success=True) + selector.record_outcome(context="y", technique="a", success=False) + assert selector.counts(context="x", technique="a") == (1, 1) + assert selector.counts(context="y", technique="a") == (0, 1) + + def test_counts_default_zero_for_unseen(self): + selector = _seeded_selector() + assert selector.counts(context="missing", technique="missing") == (0, 0) + + def test_record_outcome_keeps_pooled_global_counts_in_sync(self): + # Pooled-global counts back the O(1) pooled-backoff branch in _estimate. + # They must aggregate across contexts for a given technique. + selector = _seeded_selector(pool_threshold=5) + selector.record_outcome(context="x", technique="a", success=True) + selector.record_outcome(context="y", technique="a", success=False) + selector.record_outcome(context="z", technique="a", success=True) + selector.record_outcome(context="x", technique="b", success=True) + + # Below the local threshold, _estimate must use the pooled-global rate. + # technique "a": 2 successes / 3 attempts -> (2+1)/(3+1) = 0.75 + assert selector.success_rate(context="new_ctx", technique="a") == pytest.approx(0.75) + # technique "b": 1/1 -> (1+1)/(1+1) = 1.0 + assert selector.success_rate(context="new_ctx", technique="b") == pytest.approx(1.0) + # Unseen technique "c" -> (0+1)/(0+1) = 1.0 + assert selector.success_rate(context="new_ctx", technique="c") == pytest.approx(1.0) + + +class TestAdaptiveTechniqueSelectorEstimate: + def test_success_rate_unseen_is_one(self): + # Optimistic init: (0 + 1) / (0 + 1) = 1.0 + selector = _seeded_selector() + assert selector.success_rate(context="ctx", technique="a") == pytest.approx(1.0) + + def test_success_rate_local_when_above_threshold(self): + selector = _seeded_selector(pool_threshold=2) + for _ in range(3): + selector.record_outcome(context="ctx", technique="a", success=True) + # (3 + 1) / (3 + 1) = 1.0 + assert selector.success_rate(context="ctx", technique="a") == pytest.approx(1.0) + + def test_success_rate_pools_when_below_threshold(self): + selector = _seeded_selector(pool_threshold=5) + # Local cell has only 1 attempt (below threshold). + selector.record_outcome(context="ctx", technique="a", success=False) + # Other contexts have plenty of data for technique "a". + for _ in range(10): + selector.record_outcome(context="other", technique="a", success=True) + # Pooled estimate = (10 + 0 + 1) / (10 + 1 + 1) = 11/12. + assert selector.success_rate(context="ctx", technique="a") == pytest.approx(11 / 12) + + +class TestContextExtractors: + def test_global_context_is_constant(self): + sg = MagicMock() + assert global_context(sg) == GLOBAL_CONTEXT + + def test_harm_category_context_joins_sorted_categories(self): + sg = MagicMock() + sg.harm_categories = ["violence", "hate"] + # Multi-category seeds form their own bucket; sorting keeps the key deterministic. + assert harm_category_context(sg) == "hate|violence" + + def test_harm_category_context_single_category(self): + sg = MagicMock() + sg.harm_categories = ["violence"] + assert harm_category_context(sg) == "violence" + + def test_harm_category_context_falls_back_when_empty(self): + sg = MagicMock() + sg.harm_categories = [] + assert harm_category_context(sg) == UNCATEGORIZED_CONTEXT + + def test_harm_category_context_falls_back_when_none(self): + sg = MagicMock() + sg.harm_categories = None + assert harm_category_context(sg) == UNCATEGORIZED_CONTEXT + + +class TestAdaptiveTechniqueSelectorConcurrency: + """Concurrent record_outcome / select calls must not corrupt counts.""" + + def test_concurrent_record_outcome_preserves_total_attempts(self): + import threading + + selector = _seeded_selector(pool_threshold=1) + threads_per_arm = 8 + attempts_per_thread = 100 + techniques = ["a", "b", "c", "d"] + + def worker(technique: str, success_pattern: list[bool]) -> None: + for ok in success_pattern: + selector.record_outcome(context=GLOBAL_CONTEXT, technique=technique, success=ok) + + threads: list[threading.Thread] = [] + expected_successes: dict[str, int] = dict.fromkeys(techniques, 0) + for t in techniques: + for i in range(threads_per_arm): + pattern = [(j + i) % 2 == 0 for j in range(attempts_per_thread)] + expected_successes[t] += sum(pattern) + threads.append(threading.Thread(target=worker, args=(t, pattern))) + + for th in threads: + th.start() + for th in threads: + th.join() + + # Every increment landed: no lost updates from interleaved read-modify-write. + for t in techniques: + successes, attempts = selector.counts(context=GLOBAL_CONTEXT, technique=t) + assert attempts == threads_per_arm * attempts_per_thread + assert successes == expected_successes[t] diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py new file mode 100644 index 000000000..12b1a45e2 --- /dev/null +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -0,0 +1,512 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the ``TextAdaptive`` scenario.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import SeedAttackGroup, SeedObjective +from pyrit.prompt_target import PromptTarget +from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario import BaselinePolicy +from pyrit.scenario.scenarios.adaptive.dispatcher import ( + ADAPTIVE_CONTEXT_LABEL, + AdaptiveDispatchAttack, +) +from pyrit.scenario.scenarios.adaptive.selector import ( + GLOBAL_CONTEXT, + harm_category_context, +) +from pyrit.scenario.scenarios.adaptive.text_adaptive import TextAdaptive +from pyrit.score import TrueFalseScorer + +_MOCK_MANY_SHOT_EXAMPLES = [{"question": f"q{i}", "answer": f"a{i}"} for i in range(100)] + + +def _mock_id(name: str) -> ComponentIdentifier: + return ComponentIdentifier(class_name=name, class_module="test") + + +@pytest.fixture +def mock_objective_target() -> MagicMock: + mock = MagicMock(spec=PromptTarget) + mock.get_identifier.return_value = _mock_id("MockObjectiveTarget") + return mock + + +@pytest.fixture +def mock_objective_scorer() -> MagicMock: + mock = MagicMock(spec=TrueFalseScorer) + mock.get_identifier.return_value = _mock_id("MockObjectiveScorer") + return mock + + +@pytest.fixture(autouse=True) +def reset_technique_registry(): + """Reset registries and the cached strategy class between tests.""" + from pyrit.registry import TargetRegistry + + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + TextAdaptive._cached_strategy_class = None + yield + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + TextAdaptive._cached_strategy_class = None + + +@pytest.fixture(autouse=True) +def patch_many_shot_load(): + with patch( + "pyrit.executor.attack.single_turn.many_shot_jailbreak.load_many_shot_jailbreaking_dataset", + return_value=_MOCK_MANY_SHOT_EXAMPLES, + ): + yield + + +@pytest.fixture +def mock_runtime_env(): + with patch.dict( + "os.environ", + { + "OPENAI_CHAT_ENDPOINT": "https://test.openai.azure.com/", + "OPENAI_CHAT_KEY": "test-key", + "OPENAI_CHAT_MODEL": "gpt-4", + }, + ): + yield + + +def _make_seed_group(*, value: str, harm_categories: list[str] | None = None) -> SeedAttackGroup: + return SeedAttackGroup(seeds=[SeedObjective(value=value, harm_categories=harm_categories)]) + + +def _make_fake_factory(*, seed_technique=None, adversarial_chat=None) -> MagicMock: + """Return a stub attack-technique factory that produces a fake ``AttackTechnique``. + + Mocks the surface ``AdaptiveScenario._build_techniques_dict`` consumes + (``factory.create(...)`` and ``factory.adversarial_chat``). + """ + fake_technique = MagicMock() + fake_technique.attack = MagicMock(name="fake-attack-strategy") + fake_technique.seed_technique = seed_technique + factory = MagicMock() + factory.create.return_value = fake_technique + factory.adversarial_chat = adversarial_chat + return factory + + +FIXTURES = ["patch_central_database", "mock_runtime_env"] + + +@pytest.mark.usefixtures(*FIXTURES) +class TestTextAdaptiveBasics: + def test_version(self): + assert TextAdaptive.VERSION == 1 + + def test_baseline_forbidden(self): + assert TextAdaptive.BASELINE_POLICY is BaselinePolicy.Forbidden + + def test_default_dataset_config(self): + config = TextAdaptive.default_dataset_config() + assert isinstance(config, DatasetConfiguration) + assert config.max_dataset_size == 4 + + def test_required_datasets_non_empty(self): + assert len(TextAdaptive.required_datasets()) > 0 + + def test_get_strategy_class_is_cached(self): + cls_a = TextAdaptive.get_strategy_class() + cls_b = TextAdaptive.get_strategy_class() + assert cls_a is cls_b + + def test_get_default_strategy(self): + strat = TextAdaptive.get_default_strategy() + # The default aggregate must resolve to something runnable. + assert strat is not None + + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + def test_init_stores_adaptive_params(self, mock_get_scorer, mock_objective_scorer): + mock_get_scorer.return_value = mock_objective_scorer + scenario = TextAdaptive( + epsilon=0.4, + pool_threshold=5, + max_attempts_per_objective=7, + seed=42, + ) + assert scenario._epsilon == 0.4 + assert scenario._pool_threshold == 5 + assert scenario._max_attempts_per_objective == 7 + assert scenario._seed == 42 + + +@pytest.mark.usefixtures(*FIXTURES) +class TestTextAdaptiveAtomicAttacks: + """Tests for ``_get_atomic_attacks_async`` overriding.""" + + async def _build_scenario_and_attacks( + self, + *, + mock_objective_target, + mock_objective_scorer, + seed_groups: dict[str, list[SeedAttackGroup]], + **scenario_kwargs, + ): + with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=seed_groups): + scenario = TextAdaptive( + objective_scorer=mock_objective_scorer, + **scenario_kwargs, + ) + await scenario.initialize_async( + objective_target=mock_objective_target, + include_baseline=False, + ) + return scenario, await scenario._get_atomic_attacks_async() + + async def test_one_atomic_per_objective(self, mock_objective_target, mock_objective_scorer): + groups = { + "violence": [ + _make_seed_group(value="obj-v1", harm_categories=["violence"]), + _make_seed_group(value="obj-v2", harm_categories=["violence"]), + ], + "hate": [ + _make_seed_group(value="obj-h1", harm_categories=["hate"]), + ], + } + _scenario, attacks = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups, + ) + assert len(attacks) == 3 + for atomic in attacks: + # Each atomic carries exactly one seed group. + assert len(atomic.objectives) == 1 + + async def test_atomics_share_one_selector_across_dispatchers(self, mock_objective_target, mock_objective_scorer): + groups = { + "violence": [ + _make_seed_group(value="obj-v1", harm_categories=["violence"]), + _make_seed_group(value="obj-v2", harm_categories=["violence"]), + ], + } + _scenario, attacks = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups, + ) + dispatchers = [atomic._attack_technique.attack for atomic in attacks] + # Each objective gets its own dispatcher (bound to its own seed group)... + assert len({id(d) for d in dispatchers}) == len(attacks) + for d in dispatchers: + assert isinstance(d, AdaptiveDispatchAttack) + # ...but they all share the same selector so learning is global. + selectors = {id(d._selector) for d in dispatchers} + assert len(selectors) == 1 + + async def test_global_context_label_when_using_global_extractor(self, mock_objective_target, mock_objective_scorer): + groups = { + "violence": [_make_seed_group(value="obj-1", harm_categories=["violence"])], + "hate": [_make_seed_group(value="obj-2", harm_categories=["hate"])], + } + _scenario, attacks = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups, + ) + for atomic in attacks: + assert atomic._memory_labels[ADAPTIVE_CONTEXT_LABEL] == GLOBAL_CONTEXT + + async def test_harm_category_extractor_partitions_labels(self, mock_objective_target, mock_objective_scorer): + groups = { + "violence": [_make_seed_group(value="obj-v", harm_categories=["violence"])], + "hate": [_make_seed_group(value="obj-h", harm_categories=["hate"])], + "uncat": [_make_seed_group(value="obj-u", harm_categories=None)], + } + _scenario, attacks = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups, + context_extractor=harm_category_context, + ) + contexts = {atomic._memory_labels[ADAPTIVE_CONTEXT_LABEL] for atomic in attacks} + # Each objective gets its own context bucket from harm_category_context. + assert contexts == {"violence", "hate", "_uncategorized"} + + async def test_atomic_names_are_unique(self, mock_objective_target, mock_objective_scorer): + groups = { + "violence": [_make_seed_group(value=f"obj-{i}", harm_categories=["violence"]) for i in range(5)], + } + _scenario, attacks = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups, + ) + names = [atomic.atomic_attack_name for atomic in attacks] + assert len(set(names)) == len(names) + + async def test_display_group_is_dataset_name(self, mock_objective_target, mock_objective_scorer): + groups = { + "violence": [_make_seed_group(value="obj-v", harm_categories=["violence"])], + "hate": [_make_seed_group(value="obj-h", harm_categories=["hate"])], + } + _scenario, attacks = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups, + ) + display_groups = {atomic.display_group for atomic in attacks} + assert display_groups == {"violence", "hate"} + + async def test_no_usable_techniques_raises(self, mock_objective_target, mock_objective_scorer): + groups = {"violence": [_make_seed_group(value="obj")]} + with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + await scenario.initialize_async( + objective_target=mock_objective_target, + include_baseline=False, + ) + # Force the factory map to be empty. + with patch.object(scenario, "_get_attack_technique_factories", return_value={}): + with pytest.raises(ValueError, match="no usable techniques"): + await scenario._get_atomic_attacks_async() + + async def test_techniques_with_seed_technique_are_kept(self, mock_objective_target, mock_objective_scorer): + """Factories that declare a ``seed_technique`` participate in the pool + (the old behavior silently dropped them with a warning). + """ + groups = {"violence": [_make_seed_group(value="obj")]} + plain_factory = _make_fake_factory(seed_technique=None) + seeded_factory = _make_fake_factory(seed_technique=MagicMock(name="seed_technique")) + + with ( + patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=True), + ): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + with patch.object( + scenario, + "_get_attack_technique_factories", + return_value={"prompt_sending": plain_factory, "many_shot": seeded_factory}, + ): + await scenario.initialize_async( + objective_target=mock_objective_target, + include_baseline=False, + ) + attacks = scenario._atomic_attacks + + assert len(attacks) == 1 + dispatcher = attacks[0]._attack_technique.attack + assert isinstance(dispatcher, AdaptiveDispatchAttack) + # Both factories survive; in particular the seeded one is no longer + # silently dropped. + assert "prompt_sending" in dispatcher._techniques + assert "many_shot" in dispatcher._techniques + + async def test_incompatible_seed_technique_is_filtered_per_objective( + self, mock_objective_target, mock_objective_scorer + ): + """Per-objective candidate pool drops techniques whose seed_technique + is incompatible with the seed group; compatible techniques remain. + """ + groups = {"violence": [_make_seed_group(value="obj")]} + plain_factory = _make_fake_factory(seed_technique=None) + incompatible_factory = _make_fake_factory(seed_technique=MagicMock(name="incompatible_seed_technique")) + + with ( + patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=False), + ): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + with patch.object( + scenario, + "_get_attack_technique_factories", + return_value={"prompt_sending": plain_factory, "many_shot": incompatible_factory}, + ): + await scenario.initialize_async( + objective_target=mock_objective_target, + include_baseline=False, + ) + attacks = scenario._atomic_attacks + + assert len(attacks) == 1 + dispatcher = attacks[0]._attack_technique.attack + # Only the plain technique survives; the seed_technique-bearing one is filtered out + # because is_compatible_with_technique returned False. + assert "prompt_sending" in dispatcher._techniques + assert "many_shot" not in dispatcher._techniques + + async def test_objective_skipped_when_no_compatible_techniques( + self, mock_objective_target, mock_objective_scorer, caplog + ): + """When every technique requires an incompatible seed_technique, the + objective is dropped with a warning rather than producing an atomic + attack with an empty technique pool. + """ + groups = { + "violence": [_make_seed_group(value="obj-keep")], + "hate": [_make_seed_group(value="obj-skip")], + } + seeded_factory = _make_fake_factory(seed_technique=MagicMock(name="seed_technique")) + + # is_compatible_with_technique returns True for "obj-keep", False for "obj-skip". + def _selective_compat(self_group, *, technique): + return self_group.objective.value == "obj-keep" + + with ( + patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object(SeedAttackGroup, "is_compatible_with_technique", _selective_compat), + ): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + with patch.object( + scenario, + "_get_attack_technique_factories", + return_value={"prompt_sending": seeded_factory}, + ): + import logging + + with caplog.at_level(logging.WARNING): + await scenario.initialize_async( + objective_target=mock_objective_target, + include_baseline=False, + ) + attacks = scenario._atomic_attacks + + # Only the compatible objective produced an atomic attack. + assert len(attacks) == 1 + # Skip was logged with the affected objective value. + assert any("obj-skip" in record.getMessage() for record in caplog.records) + + +@pytest.mark.usefixtures(*FIXTURES) +class TestTextAdaptiveSelectorRehydration: + """When resuming, prior dispatch trails should replay into the new selector.""" + + def _build_scenario_no_resume_id(self, *, scorer): + return TextAdaptive(objective_scorer=scorer) + + def test_no_scenario_result_id_is_noop(self, mock_objective_scorer): + from pyrit.scenario.scenarios.adaptive.selector import AdaptiveTechniqueSelector + + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + selector = AdaptiveTechniqueSelector() + # No scenario_result_id set -> early return, no errors, no replays. + scenario._rehydrate_selector_from_memory(selector=selector, known_techniques={"a", "b"}) + assert selector.snapshot() == {} + + def test_replays_attempts_from_metadata(self, mock_objective_scorer): + from pyrit.models import AttackResult + from pyrit.scenario.scenarios.adaptive.selector import AdaptiveTechniqueSelector + + scenario = TextAdaptive(objective_scorer=mock_objective_scorer, scenario_result_id="rid") + + prior_result = MagicMock() + prior_result.attack_results = { + "adaptive_violence_o1": [ + AttackResult( + conversation_id="c1", + objective="o1", + metadata={ + "adaptive_attempts": [ + {"technique": "a", "outcome": "failure"}, + {"technique": "b", "outcome": "success"}, + ], + "adaptive_context": "violence", + }, + ), + ], + "adaptive_hate_o2": [ + AttackResult( + conversation_id="c2", + objective="o2", + metadata={ + "adaptive_attempts": [{"technique": "a", "outcome": "success"}], + "adaptive_context": "hate", + }, + ), + ], + } + + selector = AdaptiveTechniqueSelector() + with patch.object(scenario._memory, "get_scenario_results", return_value=[prior_result]): + scenario._rehydrate_selector_from_memory(selector=selector, known_techniques={"a", "b"}) + + # Trails replayed verbatim into the per-context table. + assert selector.counts(context="violence", technique="a") == (0, 1) + assert selector.counts(context="violence", technique="b") == (1, 1) + assert selector.counts(context="hate", technique="a") == (1, 1) + + def test_skips_unknown_techniques(self, mock_objective_scorer): + from pyrit.models import AttackResult + from pyrit.scenario.scenarios.adaptive.selector import AdaptiveTechniqueSelector + + scenario = TextAdaptive(objective_scorer=mock_objective_scorer, scenario_result_id="rid") + prior_result = MagicMock() + prior_result.attack_results = { + "x": [ + AttackResult( + conversation_id="c1", + objective="o1", + metadata={ + "adaptive_attempts": [ + {"technique": "removed_technique", "outcome": "success"}, + {"technique": "a", "outcome": "failure"}, + ], + "adaptive_context": "ctx", + }, + ), + ], + } + + selector = AdaptiveTechniqueSelector() + with patch.object(scenario._memory, "get_scenario_results", return_value=[prior_result]): + scenario._rehydrate_selector_from_memory(selector=selector, known_techniques={"a"}) + + # Only the known technique was recorded. + assert selector.counts(context="ctx", technique="a") == (0, 1) + assert selector.counts(context="ctx", technique="removed_technique") == (0, 0) + + def test_ignores_results_without_adaptive_metadata(self, mock_objective_scorer): + from pyrit.models import AttackResult + from pyrit.scenario.scenarios.adaptive.selector import AdaptiveTechniqueSelector + + scenario = TextAdaptive(objective_scorer=mock_objective_scorer, scenario_result_id="rid") + prior_result = MagicMock() + prior_result.attack_results = { + "baseline": [AttackResult(conversation_id="c", objective="o", metadata={})], + } + + selector = AdaptiveTechniqueSelector() + with patch.object(scenario._memory, "get_scenario_results", return_value=[prior_result]): + scenario._rehydrate_selector_from_memory(selector=selector, known_techniques={"a"}) + assert selector.snapshot() == {} + + def test_memory_load_failure_is_swallowed(self, mock_objective_scorer): + from pyrit.scenario.scenarios.adaptive.selector import AdaptiveTechniqueSelector + + scenario = TextAdaptive(objective_scorer=mock_objective_scorer, scenario_result_id="rid") + + selector = AdaptiveTechniqueSelector() + with patch.object(scenario._memory, "get_scenario_results", side_effect=RuntimeError("db down")): + # Must not raise; selector remains empty. + scenario._rehydrate_selector_from_memory(selector=selector, known_techniques={"a"}) + assert selector.snapshot() == {} + + +@pytest.mark.usefixtures(*FIXTURES) +class TestTextAdaptiveBaselinePolicy: + async def test_initialize_async_rejects_explicit_baseline(self, mock_objective_target, mock_objective_scorer): + groups = {"violence": [_make_seed_group(value="obj", harm_categories=["violence"])]} + with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + with pytest.raises(ValueError): + await scenario.initialize_async( + objective_target=mock_objective_target, + include_baseline=True, + ) From a151bedb6a1681ca66926654990708ba8a5ef973 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:32:56 -0700 Subject: [PATCH 07/42] MAINT: migrate adaptive scenario onto StrategyGraph (Phase 6b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces `AdaptiveStep(ScenarioStep)` as the per-objective execution unit and migrates `AdaptiveScenario` to dispatch through `StrategyGraph`. The new step extracts the per-objective adaptive loop from `AdaptiveDispatchAttack._perform_async` and emits `ScenarioStepResult` with outcome label `'success'` or `'exhausted'` (lifting the static `'done'` outcome). It duck-types the `AtomicAttack`-like attributes (`atomic_attack_name`, `objectives`, `seed_groups`, `display_group`, `filter_seed_groups_by_objectives`) so the orchestrator's resume bookkeeping continues to work without changes. `AdaptiveScenario` now overrides `_build_execution_graph` with a custom linear policy (`_build_adaptive_linear_policy`) that always dispatches via `step.process_async()` — bypassing the base class's `isinstance(_step, AtomicAttack)` branch that would otherwise flatten outcomes to `'done'`. The scenario caches its single `AdaptiveTechniqueSelector` on `self._selector` for external introspection and shares the same reference across every emitted `AdaptiveStep`. `AdaptiveDispatchAttack` is deprecated via `print_deprecation_message` pointing to `AdaptiveStep`; scheduled for removal in 0.17.0. Module docstring updated accordingly. Tests: adds `tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py` (19 tests across init validation, AtomicAttack parity, process loop, identifier shape, adaptive-context labels). Migrates 3 assertions in `test_text_adaptive.py` (selector sharing, seed-technique compat) to introspect `step._techniques`/`step._selector` directly. Suppresses dispatcher deprecation noise via module-level `pytestmark` in `test_dispatcher.py` and adds a dedicated `TestDeprecation` class that explicitly asserts the warning fires. Adaptive package: 83 tests pass (was 64). Full unit suite: 7984 passed (no regressions outside the pre-existing ODBC env failure in test_pyrit_scan.py). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/scenarios/adaptive/__init__.py | 2 + .../scenarios/adaptive/adaptive_scenario.py | 171 +++++-- .../scenarios/adaptive/adaptive_step.py | 349 +++++++++++++++ .../scenario/scenarios/adaptive/dispatcher.py | 22 +- .../scenarios/adaptive/test_adaptive_step.py | 420 ++++++++++++++++++ .../scenarios/adaptive/test_dispatcher.py | 20 + .../scenarios/adaptive/test_text_adaptive.py | 27 +- 7 files changed, 956 insertions(+), 55 deletions(-) create mode 100644 pyrit/scenario/scenarios/adaptive/adaptive_step.py create mode 100644 tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py diff --git a/pyrit/scenario/scenarios/adaptive/__init__.py b/pyrit/scenario/scenarios/adaptive/__init__.py index d0bd978c2..bd37235fb 100644 --- a/pyrit/scenario/scenarios/adaptive/__init__.py +++ b/pyrit/scenario/scenarios/adaptive/__init__.py @@ -4,6 +4,7 @@ """Adaptive scenario classes.""" from pyrit.scenario.scenarios.adaptive.adaptive_scenario import AdaptiveScenario +from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep from pyrit.scenario.scenarios.adaptive.dispatcher import ( ADAPTIVE_CONTEXT_LABEL, AdaptiveDispatchAttack, @@ -20,6 +21,7 @@ "ADAPTIVE_CONTEXT_LABEL", "AdaptiveDispatchAttack", "AdaptiveScenario", + "AdaptiveStep", "AdaptiveTechniqueSelector", "ContextExtractor", "TextAdaptive", diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index 723849ce9..9b0588e01 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -19,15 +19,15 @@ import logging import random import uuid -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, ClassVar, cast from pyrit.executor.attack import AttackScoringConfig -from pyrit.scenario.core.atomic_attack import AtomicAttack -from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.scenario import BaselinePolicy, Scenario +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy +from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep from pyrit.scenario.scenarios.adaptive.dispatcher import ( ADAPTIVE_CONTEXT_LABEL, - AdaptiveDispatchAttack, TechniqueBundle, ) from pyrit.scenario.scenarios.adaptive.selector import ( @@ -37,8 +37,11 @@ ) if TYPE_CHECKING: + from collections.abc import Sequence + from pyrit.models import SeedAttackGroup from pyrit.prompt_target import PromptTarget + from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.score import TrueFalseScorer logger = logging.getLogger(__name__) @@ -96,6 +99,9 @@ def __init__( self._max_attempts_per_objective = max_attempts_per_objective self._seed = seed self._context_extractor = context_extractor + # Populated by _get_atomic_attacks_async; consumed by _build_execution_graph + # only when an override path needs to introspect it externally. + self._selector: AdaptiveTechniqueSelector | None = None super().__init__( version=self.VERSION, @@ -106,18 +112,27 @@ def __init__( async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ - Build one ``AtomicAttack`` per objective. - - Each objective gets a freshly constructed ``AdaptiveDispatchAttack`` - bound to its seed group, but all dispatchers share the same selector - so learning accumulates across objectives. Per-objective, techniques - whose ``seed_technique`` is incompatible with the seed group are - filtered out; objectives left with no compatible techniques are skipped. + Build one :class:`AdaptiveStep` per objective. + + Each objective gets a freshly constructed step bound to its seed group, + but all steps share the same selector so learning accumulates across + objectives. Per-objective, techniques whose ``seed_technique`` is + incompatible with the seed group are filtered out; objectives left + with no compatible techniques are skipped. + + The return type is :class:`list[AtomicAttack]` for parity with the base + ``Scenario._get_atomic_attacks_async`` contract — the orchestrator's + resume bookkeeping treats steps via the duck-typed attributes + :class:`AdaptiveStep` provides (``atomic_attack_name``, ``objectives``, + ``seed_groups``, ``display_group``, ``filter_seed_groups_by_objectives``). + Execution dispatch in :meth:`_build_execution_graph` calls + ``step.process_async`` directly, bypassing the default linear policy's + ``AtomicAttack.run_async`` branch. Returns: - list[AtomicAttack]: One ``AtomicAttack`` per objective with at - least one compatible technique. Empty if every seed group - is incompatible with every selected technique. + list[AtomicAttack]: One step per objective with at least one + compatible technique. Empty if every seed group is incompatible + with every selected technique. Raises: ValueError: If ``self._objective_target`` is not set, or if @@ -135,21 +150,23 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: ) # On resume, replay prior attempt outcomes from persisted metadata. self._rehydrate_selector_from_memory(selector=selector, known_techniques=set(techniques)) + # Cache the selector so _build_execution_graph reuses the rehydrated state. + self._selector = selector seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() - atomic_attacks: list[AtomicAttack] = [] + steps: list[AdaptiveStep] = [] for dataset_name, seed_groups in seed_groups_by_dataset.items(): for seed_group in seed_groups: - atomic = self._build_atomic_for_seed_group( + step = self._build_step_for_seed_group( dataset_name=dataset_name, seed_group=seed_group, techniques=techniques, selector=selector, ) - if atomic is not None: - atomic_attacks.append(atomic) + if step is not None: + steps.append(step) - return atomic_attacks + return cast("list[AtomicAttack]", steps) def _build_techniques_dict( self, @@ -202,24 +219,24 @@ def _build_techniques_dict( return techniques - def _build_atomic_for_seed_group( + def _build_step_for_seed_group( self, *, dataset_name: str, seed_group: SeedAttackGroup, techniques: dict[str, TechniqueBundle], selector: AdaptiveTechniqueSelector, - ) -> AtomicAttack | None: + ) -> AdaptiveStep | None: """ - Build a single ``AtomicAttack`` for one ``SeedAttackGroup``. + Build a single :class:`AdaptiveStep` for one ``SeedAttackGroup``. Filters the technique pool down to those whose ``seed_technique`` (if - any) is compatible with this seed group, then constructs a dedicated - ``AdaptiveDispatchAttack`` bound to this seed group. + any) is compatible with this seed group, then constructs an + :class:`AdaptiveStep` bound to it. Returns: - AtomicAttack | None: The constructed atomic attack, or ``None`` when - no techniques are compatible (caller skips the objective). + AdaptiveStep | None: The constructed step, or ``None`` when no + techniques are compatible (caller skips the objective). Raises: ValueError: If ``self._objective_target`` is not set (defensive @@ -248,26 +265,104 @@ def _build_atomic_for_seed_group( objective_id = seed_group.objective.id if seed_group.objective.id else uuid.uuid4() atomic_attack_name = f"{self._atomic_attack_prefix}_{dataset_name}_{objective_id}" - dispatcher = AdaptiveDispatchAttack( + memory_labels = { + **self._memory_labels, + ADAPTIVE_CONTEXT_LABEL: adaptive_context, + } + return AdaptiveStep( + atomic_attack_name=atomic_attack_name, + display_group=dataset_name, objective_target=self._objective_target, techniques=compatible, selector=selector, seed_group=seed_group, objective_scorer=self._objective_scorer, max_attempts_per_objective=self._max_attempts_per_objective, + memory_labels=memory_labels, + adaptive_context=adaptive_context, ) - memory_labels = { - **self._memory_labels, - ADAPTIVE_CONTEXT_LABEL: adaptive_context, - } - return AtomicAttack( - atomic_attack_name=atomic_attack_name, - attack_technique=AttackTechnique(attack=dispatcher), - seed_groups=[seed_group], - objective_scorer=self._objective_scorer, - memory_labels=memory_labels, - display_group=dataset_name, + def _build_execution_graph( + self, + *, + steps: Sequence[ScenarioStep] | None = None, + ) -> StrategyGraph[ScenarioStep, int]: + """ + Build a linear graph that drives each :class:`AdaptiveStep` via + ``process_async`` so the ``"success"`` / ``"exhausted"`` outcome + labels survive into the orchestrator (the default policy from the + base class would dispatch ``AtomicAttack`` instances through + ``run_async`` and lose the outcome distinction). + + Args: + steps: Optional explicit step list. Defaults to + ``self._atomic_attacks``, mirroring the base class contract. + + Returns: + StrategyGraph[ScenarioStep, int]: A linear traversal whose actions + always dispatch via ``step.process_async``. + """ + effective_steps = list(steps) if steps is not None else list(self._atomic_attacks) + policy = self._build_adaptive_linear_policy(steps=effective_steps) + return StrategyGraph(policy=policy) + + def _build_adaptive_linear_policy( + self, + *, + steps: Sequence[ScenarioStep], + ) -> StrategyPolicy[ScenarioStep, int]: + """ + Build a linear policy that always dispatches via ``process_async``. + + Each policy action runs ``steps[i].process_async()`` and transitions + to state ``i + 1``; state ``len(steps)`` is the sole terminal state. + Unlike :meth:`Scenario._build_default_linear_policy` there's no + ``isinstance(_step, AtomicAttack)`` branch — adaptive steps always + go through their own process_async loop so the + ``"success"``/``"exhausted"`` outcome labels propagate unchanged. + + Args: + steps: The steps to wrap. Must be non-empty. + + Returns: + StrategyPolicy[ScenarioStep, int]: A frozen linear policy. + + Raises: + ValueError: If ``steps`` is empty. + """ + if not steps: + raise ValueError("_build_adaptive_linear_policy requires at least one step.") + + terminal_state = len(steps) + actions: dict[int, PolicyAction[ScenarioStep, int]] = {} + + for index, step in enumerate(steps): + + async def _action( + graph: StrategyGraph[ScenarioStep, int], + _step: ScenarioStep = step, + _next: int = index + 1, + ) -> tuple[int, ScenarioStepResult | None]: + graph.bind_current_step(step=_step) + try: + base_result = await _step.process_async() + merged_metadata = {"step_name": _step.name, **base_result.metadata} + result: ScenarioStepResult | None = ScenarioStepResult( + outcome=base_result.outcome, + attack_results=list(base_result.attack_results), + step_identifier=base_result.step_identifier, + metadata=merged_metadata, + ) + finally: + graph.bind_current_step(step=None) + return _next, result + + actions[index] = _action + + return StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({terminal_state}), ) def _rehydrate_selector_from_memory( diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_step.py b/pyrit/scenario/scenarios/adaptive/adaptive_step.py new file mode 100644 index 000000000..3bb39d581 --- /dev/null +++ b/pyrit/scenario/scenarios/adaptive/adaptive_step.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``AdaptiveStep`` — per-objective adaptive ``ScenarioStep`` (Phase 6b). + +Replaces ``AdaptiveDispatchAttack`` as the per-objective execution unit for +adaptive scenarios. Where the dispatcher was an ``AttackStrategy`` wrapped +inside an ``AtomicAttack`` (which itself sits inside a flat scenario loop), +the step plugs directly into the ``StrategyGraph`` event loop introduced in +Phase 5: + +* ``process_async`` runs the per-objective selector loop and returns a + ``ScenarioStepResult`` whose ``outcome`` is ``"success"`` or ``"exhausted"`` + — these become real transition labels the surrounding policy can branch on. +* The dispatch trail (``adaptive_attempts`` + ``adaptive_context``) is stamped + onto the returned ``AttackResult.metadata`` exactly as the dispatcher did, + so existing persistence / rehydration semantics are preserved. +* Inner-attack execution still flows through ``AttackExecutor`` against a + technique-merged ``SeedAttackGroup``, so the inner ``AttackResult`` rows the + techniques persist are unchanged. + +``AdaptiveStep`` provides the ``AtomicAttack``-shaped attributes that the +``Scenario`` orchestrator uses for resume bookkeeping +(``atomic_attack_name``, ``objectives``, ``seed_groups``, ``display_group``, +``filter_seed_groups_by_objectives``) without subclassing ``AtomicAttack``, +because doing so would force a single canonical ``AttackTechnique`` onto the +parent constructor that doesn't match the multi-technique reality. +""" + +from __future__ import annotations + +import logging +import uuid +from dataclasses import replace +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from pyrit.executor.attack.core.attack_executor import AttackExecutor +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import AttackOutcome, AttackResult +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.scenarios.adaptive.dispatcher import ( + ADAPTIVE_ATTEMPT_LABEL, + ADAPTIVE_TECHNIQUE_LABEL, + TechniqueBundle, +) +from pyrit.scenario.scenarios.adaptive.selector import ( + GLOBAL_CONTEXT, + AdaptiveTechniqueSelector, +) + +if TYPE_CHECKING: + from pyrit.models import SeedAttackGroup + from pyrit.prompt_target import PromptTarget + from pyrit.score import TrueFalseScorer + +logger = logging.getLogger(__name__) + + +class AdaptiveStep(ScenarioStep): + """ + Per-objective adaptive selector step. + + Owns the per-objective adaptive loop that PR #1760 placed inside + ``AdaptiveDispatchAttack``. Each call to :meth:`process_async` selects up + to ``max_attempts_per_objective`` techniques via the shared + :class:`AdaptiveTechniqueSelector`, runs each selected technique against + the bound seed group, records the outcome on the selector, and stops on + first success. + + The returned :class:`ScenarioStepResult` exposes the outcome as a real + transition label (``"success"`` or ``"exhausted"``) so a state-machine + policy can route on it. A single ``AttackResult`` is returned — a fresh + dispatcher-owned copy of the last inner result with the adaptive trail + stamped onto ``metadata`` — preserving the two-row persistence story + (inner row + outer row sharing a conversation id) described on + :class:`AdaptiveDispatchAttack`. + """ + + #: Transition labels this step can emit. ``"success"`` means at least one + #: attempt achieved :attr:`AttackOutcome.SUCCESS`; ``"exhausted"`` means + #: ``max_attempts_per_objective`` ran without success. + _OUTPUTS: tuple[str, ...] = ("success", "exhausted") + + def __init__( + self, + *, + atomic_attack_name: str, + display_group: str | None = None, + objective_target: PromptTarget, + techniques: dict[str, TechniqueBundle], + selector: AdaptiveTechniqueSelector, + seed_group: SeedAttackGroup, + objective_scorer: TrueFalseScorer | None = None, + max_attempts_per_objective: int = 3, + memory_labels: dict[str, str] | None = None, + adaptive_context: str = GLOBAL_CONTEXT, + ) -> None: + """ + Args: + atomic_attack_name: Unique key used by the scenario for resume tracking + and result persistence. Mirrors :attr:`AtomicAttack.atomic_attack_name`. + display_group: Optional label for grouping results in user-facing output. + Defaults to ``atomic_attack_name`` when ``None``. + objective_target: The target inner attacks run against. Stored for + identifier / logging parity; not called directly by the step. + techniques: Mapping ``{technique_name: TechniqueBundle}``. Must be non-empty. + selector: Shared :class:`AdaptiveTechniqueSelector` so learning accumulates + across all per-objective steps in the scenario. + seed_group: The :class:`SeedAttackGroup` this step is bound to. Each attempt + merges the chosen technique's ``seed_technique`` (if any) into this group + before execution. + objective_scorer: Scorer passed through to techniques that generate + simulated conversations. + max_attempts_per_objective: Max techniques per objective; must be ``>= 1``. + Defaults to 3. + memory_labels: Per-attempt memory labels (typically including the + ``ADAPTIVE_CONTEXT_LABEL`` stamped by the scenario). + adaptive_context: The selector context bucket for this step (e.g. + :data:`GLOBAL_CONTEXT` or a harm-category key). Stays stable for the + lifetime of the step. + + Raises: + ValueError: If ``techniques`` is empty or ``max_attempts_per_objective`` < 1. + """ + if not techniques: + raise ValueError("techniques must contain at least one attack technique") + if max_attempts_per_objective < 1: + raise ValueError(f"max_attempts_per_objective must be >= 1, got {max_attempts_per_objective}") + + self.atomic_attack_name = atomic_attack_name + self.display_group = display_group or atomic_attack_name + + self._objective_target = objective_target + self._techniques = techniques + self._selector = selector + self._seed_group = seed_group + self._objective_scorer = objective_scorer + self._max_attempts = max_attempts_per_objective + self._memory_labels = memory_labels or {} + self._adaptive_context = adaptive_context + # Attempts are inherently sequential (each one reads the selector + # state updated by the previous), so a single shared executor with + # ``max_concurrency=1`` is reused across attempts. + self._executor = AttackExecutor(max_concurrency=1) + + @property + def name(self) -> str: + """Display / resume key. Aliases :attr:`atomic_attack_name`.""" + return self.atomic_attack_name + + @property + def outputs(self) -> list[str]: + """Transition labels this step can emit.""" + return list(self._OUTPUTS) + + @property + def objectives(self) -> list[str]: + """Objectives drawn from the bound seed group (parity with ``AtomicAttack``).""" + if self._seed_group.objective is not None: + return [self._seed_group.objective.value] + return [] + + @property + def seed_groups(self) -> list[SeedAttackGroup]: + """One-element list view of the bound seed group (parity with ``AtomicAttack``).""" + return [self._seed_group] + + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: + """ + Drop the bound seed group when its objective is already complete. + + Mirrors :meth:`AtomicAttack.filter_seed_groups_by_objectives` so the + scenario's resume filter (which calls this on every step it iterates + over) works uniformly across step types. Because an adaptive step is + bound to a single seed group, this collapses to a presence check on + that single objective. + + Args: + remaining_objectives: Objectives that still need execution. + """ + if self._seed_group.objective is None: + return + if self._seed_group.objective.value not in set(remaining_objectives): + # Replace the bound seed group with a marker that has no objective + # so subsequent process_async calls are no-ops. In practice the + # orchestrator skips the step entirely when objectives is empty, + # so this branch is rarely hit; kept for defensive parity. + self._seed_group = self._seed_group # noqa: PLW0127 — explicit no-op + + async def process_async(self) -> ScenarioStepResult: + """ + Run the per-objective adaptive loop and return its outcome. + + Loops up to ``max_attempts_per_objective`` times: select a technique + via the shared selector, execute it against the merged seed group, + record the outcome on the selector, and stop early on first success. + The returned :class:`ScenarioStepResult` carries one + :class:`AttackResult` — a fresh dispatcher-owned copy of the final + inner result with the dispatch trail stamped onto ``metadata``. + + Returns: + ScenarioStepResult: ``outcome`` is ``"success"`` if any attempt + succeeded, ``"exhausted"`` otherwise. ``attack_results`` holds + the single outer result; ``metadata`` carries the step name, + the dispatch trail, and the adaptive context. + + Raises: + RuntimeError: If the loop somehow ran zero attempts (unreachable + because ``max_attempts_per_objective`` is validated >= 1). + """ + technique_names = list(self._techniques.keys()) + last_result: AttackResult | None = None + trail: list[dict[str, str]] = [] + succeeded = False + + for attempt_idx in range(self._max_attempts): + chosen = self._selector.select(context=self._adaptive_context, techniques=technique_names) + bundle = self._techniques[chosen] + attempt_labels = { + **self._memory_labels, + ADAPTIVE_TECHNIQUE_LABEL: chosen, + ADAPTIVE_ATTEMPT_LABEL: str(attempt_idx + 1), + } + + logger.debug( + "AdaptiveStep[%s]: attempt %d/%d context=%r technique=%r", + self.atomic_attack_name, + attempt_idx + 1, + self._max_attempts, + self._adaptive_context, + chosen, + ) + + result = await self._run_inner_attack_async(bundle=bundle, attempt_labels=attempt_labels) + success = result.outcome == AttackOutcome.SUCCESS + self._selector.record_outcome( + context=self._adaptive_context, + technique=chosen, + success=success, + ) + + trail.append({"technique": chosen, "outcome": result.outcome.value}) + last_result = result + + if success: + succeeded = True + break + + if last_result is None: # pragma: no cover - defensive + raise RuntimeError("AdaptiveStep ran zero attempts; this should be unreachable.") + + outcome_label = "success" if succeeded else "exhausted" + # Return a fresh outer ``AttackResult`` to avoid PK conflicts with the + # inner attack's already-persisted row (see ``AdaptiveDispatchAttack`` + # for the two-row persistence rationale). + outer_result = replace( + last_result, + attack_result_id=str(uuid.uuid4()), + timestamp=datetime.now(timezone.utc), + metadata={ + **last_result.metadata, + "adaptive_attempts": trail, + "adaptive_context": self._adaptive_context, + }, + ) + + return ScenarioStepResult( + outcome=outcome_label, + attack_results=[outer_result], + metadata={ + "step_name": self.atomic_attack_name, + "adaptive_attempts": trail, + "adaptive_context": self._adaptive_context, + }, + ) + + async def _run_inner_attack_async( + self, + *, + bundle: TechniqueBundle, + attempt_labels: dict[str, str], + ) -> AttackResult: + """ + Execute the chosen technique against the bound seed group. + + Merges ``bundle.seed_technique`` into the bound ``seed_group`` (when + present) and delegates execution to :class:`AttackExecutor`. Isolated + as a method so tests can patch the inner-attack call surface. + + Args: + bundle: The chosen technique's attack + seeds + chat. + attempt_labels: Memory labels stamped onto this attempt. + + Returns: + AttackResult: The single result produced for this attempt. + + Raises: + RuntimeError: If the executor returned no completed results and no + propagated exception (should be unreachable). + """ + if bundle.seed_technique is not None: + execution_group = self._seed_group.with_technique(technique=bundle.seed_technique) + else: + execution_group = self._seed_group + + executor_result = await self._executor.execute_attack_from_seed_groups_async( + attack=bundle.attack, + seed_groups=[execution_group], + adversarial_chat=bundle.adversarial_chat, + objective_scorer=self._objective_scorer, + memory_labels=attempt_labels, + ) + + if executor_result.completed_results: + return executor_result.completed_results[0] + if executor_result.incomplete_objectives: + raise executor_result.incomplete_objectives[0][1] + raise RuntimeError( # pragma: no cover - defensive + "AttackExecutor returned neither completed nor incomplete results." + ) + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this adaptive step. + + Captures the step name, declared outputs, ``max_attempts`` configuration, + and adaptive context. Each technique's bundled attack identifier is + nested under ``children["techniques"]`` (sorted by technique name for + hash stability) so drift in any inner attack or its seeds propagates + upward. + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + technique_ids: list[ComponentIdentifier] = [ + bundle.attack.get_identifier() for _, bundle in sorted(self._techniques.items()) + ] + return ComponentIdentifier.of( + self, + params={ + "atomic_attack_name": self.atomic_attack_name, + "outputs": list(self.outputs), + "max_attempts_per_objective": self._max_attempts, + "adaptive_context": self._adaptive_context, + }, + children={"techniques": technique_ids}, + ) diff --git a/pyrit/scenario/scenarios/adaptive/dispatcher.py b/pyrit/scenario/scenarios/adaptive/dispatcher.py index 46808bfde..36c2be762 100644 --- a/pyrit/scenario/scenarios/adaptive/dispatcher.py +++ b/pyrit/scenario/scenarios/adaptive/dispatcher.py @@ -2,14 +2,22 @@ # Licensed under the MIT license. """ -``AdaptiveDispatchAttack`` — picks an inner technique per attempt via an -``AdaptiveTechniqueSelector``, runs it, records the outcome, and loops up to -``max_attempts_per_objective`` times. Reads the per-objective context key from +``AdaptiveDispatchAttack`` — **deprecated** as of 0.15.0; use +:class:`pyrit.scenario.scenarios.adaptive.AdaptiveStep` instead. + +Picks an inner technique per attempt via an ``AdaptiveTechniqueSelector``, +runs it, records the outcome, and loops up to ``max_attempts_per_objective`` +times. Reads the per-objective context key from ``context.memory_labels[ADAPTIVE_CONTEXT_LABEL]`` (falls back to the global context). The dispatcher is bound to a single ``SeedAttackGroup`` at construction time so it can merge each chosen technique's ``seed_technique`` (when present) into the seed group before delegating execution to ``AttackExecutor``. + +Scheduled for removal in 0.17.0 once existing callers migrate to +``AdaptiveStep``, which drives the same per-objective loop through the +``StrategyGraph`` event loop and emits ``"success"`` / ``"exhausted"`` as +real transition labels. """ from __future__ import annotations @@ -20,6 +28,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any +from pyrit.common.deprecation import print_deprecation_message from pyrit.executor.attack.core.attack_executor import AttackExecutor from pyrit.executor.attack.core.attack_parameters import AttackParameters from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy @@ -125,6 +134,13 @@ def __init__( if max_attempts_per_objective < 1: raise ValueError(f"max_attempts_per_objective must be >= 1, got {max_attempts_per_objective}") + print_deprecation_message( + old_item="AdaptiveDispatchAttack", + new_item="pyrit.scenario.scenarios.adaptive.AdaptiveStep (drives the same loop " + "through StrategyGraph; emits real success/exhausted transition labels)", + removed_in="0.17.0", + ) + super().__init__( objective_target=objective_target, context_type=AdaptiveDispatchContext, diff --git a/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py b/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py new file mode 100644 index 000000000..a80d380bd --- /dev/null +++ b/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py @@ -0,0 +1,420 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for ``AdaptiveStep`` — the StrategyGraph-native replacement for AdaptiveDispatchAttack.""" + +from __future__ import annotations + +import random +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup, SeedObjective +from pyrit.scenario.core.scenario_step import ScenarioStepResult +from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep +from pyrit.scenario.scenarios.adaptive.dispatcher import ( + ADAPTIVE_ATTEMPT_LABEL, + ADAPTIVE_CONTEXT_LABEL, + ADAPTIVE_TECHNIQUE_LABEL, + TechniqueBundle, +) +from pyrit.scenario.scenarios.adaptive.selector import ( + GLOBAL_CONTEXT, + AdaptiveTechniqueSelector, +) + + +def _make_bundle(*, name: str, outcomes: list[AttackOutcome], seed_technique=None) -> TechniqueBundle: + """Build a TechniqueBundle whose attack stub yields the given outcomes in order. + + The step routes execution through ``_run_inner_attack_async``; tests + patch that method directly so we only need a placeholder attack here. + """ + attack = MagicMock(name=f"attack-{name}") + attack._outcomes = outcomes + attack._name = name + return TechniqueBundle(attack=attack, seed_technique=seed_technique) + + +def _patch_inner( + *, + step: AdaptiveStep, + bundles: dict[str, TechniqueBundle], +) -> AsyncMock: + """Replace ``_run_inner_attack_async`` with a stub backed by per-bundle outcomes. + + Returns the AsyncMock so tests can introspect call history (kwargs include + ``bundle`` and ``attempt_labels``). + """ + name_for_attack = {id(b.attack): name for name, b in bundles.items()} + counters: dict[str, int] = dict.fromkeys(bundles, 0) + + async def _stub(*, bundle: TechniqueBundle, attempt_labels: dict[str, str]) -> AttackResult: + name = name_for_attack[id(bundle.attack)] + idx = counters[name] + counters[name] = idx + 1 + outcome = bundle.attack._outcomes[idx] + return AttackResult( + conversation_id=f"conv-{name}-{idx}", + objective="obj", + outcome=outcome, + ) + + inner_mock = AsyncMock(side_effect=_stub) + step._run_inner_attack_async = inner_mock # type: ignore[method-assign] + return inner_mock + + +@pytest.fixture +def selector() -> AdaptiveTechniqueSelector: + # epsilon=0 makes selection deterministic given the table. + return AdaptiveTechniqueSelector(epsilon=0.0, pool_threshold=1, rng=random.Random(0)) + + +@pytest.fixture +def target() -> MagicMock: + return MagicMock(name="objective_target") + + +@pytest.fixture +def seed_group() -> SeedAttackGroup: + return SeedAttackGroup(seeds=[SeedObjective(value="obj")]) + + +@pytest.mark.usefixtures("patch_central_database") +class TestInit: + def test_init_rejects_empty_techniques(self, target, selector, seed_group): + with pytest.raises(ValueError, match="techniques"): + AdaptiveStep( + atomic_attack_name="step-1", + objective_target=target, + techniques={}, + selector=selector, + seed_group=seed_group, + ) + + @pytest.mark.parametrize("bad_max", [0, -1]) + def test_init_rejects_invalid_max_attempts(self, target, selector, seed_group, bad_max): + with pytest.raises(ValueError, match="max_attempts_per_objective"): + AdaptiveStep( + atomic_attack_name="step-1", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=bad_max, + ) + + def test_display_group_defaults_to_name(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step-x", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert step.display_group == "step-x" + + def test_explicit_display_group_overrides_default(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step-x", + display_group="my-group", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert step.display_group == "my-group" + + def test_outputs_are_success_and_exhausted(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step-x", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert step.outputs == ["success", "exhausted"] + assert step.name == "step-x" + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackParity: + """AdaptiveStep must expose AtomicAttack-like attributes for the resume filter.""" + + def test_objectives_drawn_from_seed_group(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert step.objectives == ["obj"] + + def test_seed_groups_is_single_element_list(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert step.seed_groups == [seed_group] + + def test_filter_seed_groups_by_objectives_is_noop_when_matched(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + # Should not raise; bound seed group's objective ("obj") is in the list. + step.filter_seed_groups_by_objectives(remaining_objectives=["obj"]) + assert step.seed_groups == [seed_group] + + +@pytest.mark.usefixtures("patch_central_database") +class TestProcess: + async def test_stops_on_first_success(self, target, selector, seed_group): + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS]), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]), + } + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=5, + ) + inner = _patch_inner(step=step, bundles=bundles) + + result = await step.process_async() + + assert isinstance(result, ScenarioStepResult) + assert result.outcome == "success" + assert inner.call_count == 1 + assert len(result.attack_results) == 1 + assert result.attack_results[0].outcome == AttackOutcome.SUCCESS + + async def test_retries_until_max_attempts_on_failure_emits_exhausted(self, target, selector, seed_group): + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.FAILURE] * 3), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.FAILURE] * 3), + } + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=3, + ) + inner = _patch_inner(step=step, bundles=bundles) + + result = await step.process_async() + + assert result.outcome == "exhausted" + assert inner.call_count == 3 + + async def test_updates_selector_on_each_attempt(self, target, selector, seed_group): + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.FAILURE, AttackOutcome.SUCCESS]), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]), + } + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=3, + ) + inner = _patch_inner(step=step, bundles=bundles) + + await step.process_async() + + total_attempts = sum(selector.counts(context=GLOBAL_CONTEXT, technique=t)[1] for t in ("a", "b")) + assert total_attempts == inner.call_count + + async def test_passes_attempt_labels_to_inner(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + memory_labels={"foo": "bar"}, + ) + inner = _patch_inner(step=step, bundles=bundles) + + await step.process_async() + + labels = inner.call_args.kwargs["attempt_labels"] + assert labels["foo"] == "bar" # caller labels preserved + assert labels[ADAPTIVE_TECHNIQUE_LABEL] == "a" + assert labels[ADAPTIVE_ATTEMPT_LABEL] == "1" + + async def test_uses_per_step_adaptive_context(self, target, selector, seed_group): + # Two techniques; "b" has been heavily rewarded under context "violence". + bundles = { + "a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS]), + "b": _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]), + } + for _ in range(5): + selector.record_outcome(context="violence", technique="b", success=True) + for _ in range(5): + selector.record_outcome(context="violence", technique="a", success=False) + + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + adaptive_context="violence", + ) + inner = _patch_inner(step=step, bundles=bundles) + await step.process_async() + + # Exploit should have picked "b" first. + chosen_bundle = inner.call_args.kwargs["bundle"] + assert chosen_bundle is bundles["b"] + + async def test_default_context_is_global(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + ) + _patch_inner(step=step, bundles=bundles) + await step.process_async() + + # The global context bucket received the update. + assert selector.counts(context=GLOBAL_CONTEXT, technique="a") == (1, 1) + + async def test_metadata_records_adaptive_trail_on_step_result_and_attack_result(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.FAILURE, AttackOutcome.SUCCESS])} + step = AdaptiveStep( + atomic_attack_name="step-trail", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=3, + ) + _patch_inner(step=step, bundles=bundles) + result = await step.process_async() + + expected_trail = [ + {"technique": "a", "outcome": "failure"}, + {"technique": "a", "outcome": "success"}, + ] + # Trail surfaces on both ScenarioStepResult.metadata (for the orchestrator + # / step_identifier) and on the wrapped AttackResult.metadata (for the + # persisted row that downstream consumers like the rehydration path read). + assert result.metadata["adaptive_attempts"] == expected_trail + assert result.metadata["adaptive_context"] == GLOBAL_CONTEXT + assert result.metadata["step_name"] == "step-trail" + assert result.attack_results[0].metadata["adaptive_attempts"] == expected_trail + assert result.attack_results[0].metadata["adaptive_context"] == GLOBAL_CONTEXT + + async def test_returns_fresh_attack_result_distinct_from_inner(self, target, selector, seed_group): + # The step must NOT return the inner attack's ``AttackResult`` instance — + # doing so would cause a duplicate-PK insert when both the inner attack's + # and the scenario's post-execute persistence paths try to write the same + # row. Verify the wrapped result has a fresh ``attack_result_id`` while + # preserving the inner's identifying fields and stamping the trail. + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + ) + inner_ids: list[str] = [] + + async def _spy(*, bundle, attempt_labels): + inner_result = AttackResult( + conversation_id="conv-a-0", + objective="obj", + outcome=AttackOutcome.SUCCESS, + ) + inner_ids.append(inner_result.attack_result_id) + return inner_result + + step._run_inner_attack_async = AsyncMock(side_effect=_spy) # type: ignore[method-assign] + + result = await step.process_async() + outer = result.attack_results[0] + + assert len(inner_ids) == 1 + assert outer.attack_result_id != inner_ids[0] + assert outer.conversation_id == "conv-a-0" + assert outer.outcome == AttackOutcome.SUCCESS + assert outer.metadata["adaptive_attempts"] == [{"technique": "a", "outcome": "success"}] + assert outer.metadata["adaptive_context"] == GLOBAL_CONTEXT + + +@pytest.mark.usefixtures("patch_central_database") +class TestIdentifier: + def test_identifier_nests_technique_identifiers(self, target, selector, seed_group): + bundle_a = _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS]) + bundle_b = _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]) + # Stub each technique's identifier so the call surface is observable. + from pyrit.identifiers import ComponentIdentifier + + bundle_a.attack.get_identifier.return_value = ComponentIdentifier(class_name="A", class_module="test") + bundle_b.attack.get_identifier.return_value = ComponentIdentifier(class_name="B", class_module="test") + step = AdaptiveStep( + atomic_attack_name="step-id", + objective_target=target, + techniques={"a": bundle_a, "b": bundle_b}, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=5, + adaptive_context="violence", + ) + + identifier = step.get_identifier() + # Params capture the step's behavioral identity. + assert identifier.params["atomic_attack_name"] == "step-id" + assert identifier.params["outputs"] == ["success", "exhausted"] + assert identifier.params["max_attempts_per_objective"] == 5 + assert identifier.params["adaptive_context"] == "violence" + # Children nest each technique's identifier under "techniques", + # sorted by technique name for deterministic hash stability. + nested = identifier.children["techniques"] + assert isinstance(nested, list) + assert [child.class_name for child in nested] == ["A", "B"] + + +@pytest.mark.usefixtures("patch_central_database") +class TestPersistAdaptiveContextLabel: + """The scenario stamps ADAPTIVE_CONTEXT_LABEL into memory_labels at construction.""" + + async def test_context_label_round_trips_through_attempt_labels(self, target, selector, seed_group): + bundles = {"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])} + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques=bundles, + selector=selector, + seed_group=seed_group, + memory_labels={ADAPTIVE_CONTEXT_LABEL: "violence"}, + adaptive_context="violence", + ) + inner = _patch_inner(step=step, bundles=bundles) + await step.process_async() + + labels = inner.call_args.kwargs["attempt_labels"] + assert labels[ADAPTIVE_CONTEXT_LABEL] == "violence" diff --git a/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py b/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py index 4be4ffbb6..adc209bc9 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py +++ b/tests/unit/scenario/scenarios/adaptive/test_dispatcher.py @@ -21,6 +21,13 @@ AdaptiveTechniqueSelector, ) +# ``AdaptiveDispatchAttack`` is deprecated as of 0.15.0 in favor of +# ``AdaptiveStep`` (see ``test_adaptive_step.py``). Suppress the per-instantiation +# DeprecationWarning here so the regression suite for the dispatcher's existing +# behavior stays clean during the deprecation window. The warning is asserted +# explicitly in a dedicated test below. +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") + def _make_bundle(*, name: str, outcomes: list[AttackOutcome], seed_technique=None) -> TechniqueBundle: """Build a TechniqueBundle whose attack stub yields the given outcomes in order. @@ -301,3 +308,16 @@ def test_validate_accepts_normal_objective(self, target, selector, seed_group): ) # Does not raise. dispatcher._validate_context(context=_make_context(objective="ok")) + + +@pytest.mark.usefixtures("patch_central_database") +@pytest.mark.filterwarnings("default::DeprecationWarning") +class TestDeprecation: + def test_instantiation_emits_deprecation_warning(self, target, selector, seed_group): + with pytest.warns(DeprecationWarning, match="AdaptiveDispatchAttack.*AdaptiveStep"): + AdaptiveDispatchAttack( + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index 12b1a45e2..dd1276ee6 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -15,9 +15,9 @@ from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.core.scenario import BaselinePolicy +from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep from pyrit.scenario.scenarios.adaptive.dispatcher import ( ADAPTIVE_CONTEXT_LABEL, - AdaptiveDispatchAttack, ) from pyrit.scenario.scenarios.adaptive.selector import ( GLOBAL_CONTEXT, @@ -201,13 +201,11 @@ async def test_atomics_share_one_selector_across_dispatchers(self, mock_objectiv mock_objective_scorer=mock_objective_scorer, seed_groups=groups, ) - dispatchers = [atomic._attack_technique.attack for atomic in attacks] - # Each objective gets its own dispatcher (bound to its own seed group)... - assert len({id(d) for d in dispatchers}) == len(attacks) - for d in dispatchers: - assert isinstance(d, AdaptiveDispatchAttack) + # Each objective is now driven by its own AdaptiveStep instance... + assert all(isinstance(step, AdaptiveStep) for step in attacks) + assert len({id(step) for step in attacks}) == len(attacks) # ...but they all share the same selector so learning is global. - selectors = {id(d._selector) for d in dispatchers} + selectors = {id(step._selector) for step in attacks} assert len(selectors) == 1 async def test_global_context_label_when_using_global_extractor(self, mock_objective_target, mock_objective_scorer): @@ -302,12 +300,12 @@ async def test_techniques_with_seed_technique_are_kept(self, mock_objective_targ attacks = scenario._atomic_attacks assert len(attacks) == 1 - dispatcher = attacks[0]._attack_technique.attack - assert isinstance(dispatcher, AdaptiveDispatchAttack) + step = attacks[0] + assert isinstance(step, AdaptiveStep) # Both factories survive; in particular the seeded one is no longer # silently dropped. - assert "prompt_sending" in dispatcher._techniques - assert "many_shot" in dispatcher._techniques + assert "prompt_sending" in step._techniques + assert "many_shot" in step._techniques async def test_incompatible_seed_technique_is_filtered_per_objective( self, mock_objective_target, mock_objective_scorer @@ -336,11 +334,12 @@ async def test_incompatible_seed_technique_is_filtered_per_objective( attacks = scenario._atomic_attacks assert len(attacks) == 1 - dispatcher = attacks[0]._attack_technique.attack + step = attacks[0] + assert isinstance(step, AdaptiveStep) # Only the plain technique survives; the seed_technique-bearing one is filtered out # because is_compatible_with_technique returned False. - assert "prompt_sending" in dispatcher._techniques - assert "many_shot" not in dispatcher._techniques + assert "prompt_sending" in step._techniques + assert "many_shot" not in step._techniques async def test_objective_skipped_when_no_compatible_techniques( self, mock_objective_target, mock_objective_scorer, caplog From e548dd29293e83a2945d24e051ef53617ad9be4b Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:53:20 -0700 Subject: [PATCH 08/42] TEST: augment OutcomeScorer coverage (Phase 10a) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../score/decorators/test_outcome_scorer.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/tests/unit/score/decorators/test_outcome_scorer.py b/tests/unit/score/decorators/test_outcome_scorer.py index a7e60ef94..bcb63130f 100644 --- a/tests/unit/score/decorators/test_outcome_scorer.py +++ b/tests/unit/score/decorators/test_outcome_scorer.py @@ -3,6 +3,9 @@ """Tests for ``pyrit.score.decorators.outcome_scorer``.""" +from __future__ import annotations + +from typing import TYPE_CHECKING from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 @@ -13,6 +16,9 @@ from pyrit.score import Scorer from pyrit.score.decorators import OutcomeScorer +if TYPE_CHECKING: + from collections.abc import Callable + def _make_score(*, value: str, score_type: str = "true_false") -> Score: return Score( @@ -155,3 +161,131 @@ async def test_resolve_outcome_matches_against_any_score_in_list(): ) label = await outer.resolve_outcome_async(_make_message()) assert label == "hit" + + +def test_unscored_sentinel_value_is_stable(): + """Lock the public sentinel string so downstream policies can declare it.""" + assert OutcomeScorer.UNSCORED == "unscored" + + +async def test_resolve_outcome_returns_unscored_when_wrapped_scorer_returns_none(): + """`if not scores` must treat ``None`` like an empty list, not crash.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=None) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"hit": lambda s: True}, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label == OutcomeScorer.UNSCORED + assert label in outer.outcomes + + +async def test_resolve_outcome_unscored_is_declared_in_outcomes(): + """When no predicate matches, the returned sentinel must be in ``outcomes``.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="0.5", score_type="float_scale")]) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + "violation": lambda s: s.score_value == "true", + "refusal": lambda s: s.score_value == "false", + }, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label == OutcomeScorer.UNSCORED + assert label in outer.outcomes + + +@pytest.mark.parametrize( + ("scorer_values", "score_type"), + [ + ([], "true_false"), + ([None], "true_false"), + (["true"], "true_false"), + (["false"], "true_false"), + (["0.5"], "float_scale"), + (["false", "true"], "true_false"), + (["true", "false"], "true_false"), + (["0.0", "0.5", "1.0"], "float_scale"), + ], +) +async def test_resolved_label_is_always_in_declared_outcomes(scorer_values, score_type): + """Invariant: every label ``resolve_outcome_async`` returns is in ``outcomes``.""" + if scorer_values == [None]: + return_value = None + else: + return_value = [_make_score(value=v, score_type=score_type) for v in scorer_values] + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=return_value) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + "violation": lambda s: s.score_value == "true", + "refusal": lambda s: s.score_value == "false", + }, + ) + + label = await outer.resolve_outcome_async(_make_message()) + assert label in outer.outcomes + + +async def test_resolve_outcome_propagates_wrapped_scorer_exception(): + """Errors from the wrapped scorer must bubble up; the wrapper is not a swallow.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(side_effect=RuntimeError("boom")) + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"hit": lambda s: True}, + ) + + with pytest.raises(RuntimeError, match="boom"): + await outer.resolve_outcome_async(_make_message()) + + +async def test_resolve_outcome_propagates_predicate_exception(): + """Errors from a predicate must bubble up rather than degrading to ``unscored``.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="true")]) + + def _bad_predicate(score: Score) -> bool: + raise ValueError("predicate failure") + + outer = OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={"hit": _bad_predicate}, + ) + + with pytest.raises(ValueError, match="predicate failure"): + await outer.resolve_outcome_async(_make_message()) + + +async def test_init_defensively_copies_outcome_map(): + """Mutating the input ``outcome_map`` after init must not affect resolution.""" + scorer = MagicMock(spec=Scorer) + scorer.score_async = AsyncMock(return_value=[_make_score(value="true")]) + outcome_map: dict[str, Callable[[Score], bool]] = { + "violation": lambda s: s.score_value == "true", + } + outer = OutcomeScorer(wrapped_scorer=scorer, outcome_map=outcome_map) + + outcome_map.clear() + outcome_map["other"] = lambda s: True + + assert outer.outcomes == ["violation", OutcomeScorer.UNSCORED] + label = await outer.resolve_outcome_async(_make_message()) + assert label == "violation" + + +def test_outcome_scorer_is_not_a_scorer_subclass(): + """``OutcomeScorer`` is a composition wrapper; the canonical identity path is + ``wrapped_scorer.get_identifier()``. Locking this in prevents accidental + subclassing that would shift identifier composition semantics.""" + inner = MagicMock(spec=Scorer) + outer = OutcomeScorer(wrapped_scorer=inner, outcome_map={"hit": lambda s: True}) + + assert not isinstance(outer, Scorer) + assert not hasattr(outer, "get_identifier") + assert outer.wrapped_scorer is inner From f2f9f08fb7d31cf04f83c9ed4fd2032b3ee86582 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:54:13 -0700 Subject: [PATCH 09/42] TEST: augment ScenarioStep + adapter coverage (Phase 10b) Adds regression coverage for the Phase 2 ScenarioStep ABC and the AtomicAttack ScenarioStep adapter: - ScenarioStepResult: outcome is required; metadata/attack_results default factories produce fresh per-instance containers (Python mutable-default footgun); accepts all four fields when provided. - ScenarioStep ABC: subclass missing process_async cannot instantiate; subclass that overrides only process_async inherits the default _build_identifier. - AtomicAttack adapter: filter_seed_groups_by_objectives is keyword-only and correctly filters/preserves/empties seed_groups + objectives. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../test_atomic_attack_scenario_step.py | 61 ++++++++++++++-- tests/unit/scenario/test_scenario_step.py | 69 +++++++++++++++++++ 2 files changed, 124 insertions(+), 6 deletions(-) diff --git a/tests/unit/scenario/test_atomic_attack_scenario_step.py b/tests/unit/scenario/test_atomic_attack_scenario_step.py index c3593a720..7accffc80 100644 --- a/tests/unit/scenario/test_atomic_attack_scenario_step.py +++ b/tests/unit/scenario/test_atomic_attack_scenario_step.py @@ -140,9 +140,7 @@ def test_identifier_hash_differs_when_name_differs(self, mock_attack, seed_group class TestAtomicAttackProcessAsync: """``process_async`` wraps ``run_async`` into a ``ScenarioStepResult``.""" - async def test_returns_scenario_step_result_with_done_outcome( - self, mock_attack, seed_groups, attack_results - ): + async def test_returns_scenario_step_result_with_done_outcome(self, mock_attack, seed_groups, attack_results): atomic = AtomicAttack( attack_technique=AttackTechnique(attack=mock_attack), seed_groups=seed_groups, @@ -161,9 +159,7 @@ async def test_returns_scenario_step_result_with_done_outcome( assert result.outcome == "done" assert result.attack_results == attack_results - async def test_metadata_carries_incomplete_objectives( - self, mock_attack, seed_groups, attack_results - ): + async def test_metadata_carries_incomplete_objectives(self, mock_attack, seed_groups, attack_results): atomic = AtomicAttack( attack_technique=AttackTechnique(attack=mock_attack), seed_groups=seed_groups, @@ -213,3 +209,56 @@ async def test_returns_empty_results_when_no_completions(self, mock_attack, seed assert result.outcome == "done" assert result.attack_results == [] assert len(result.metadata["incomplete_objectives"]) == 1 + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackFilterSeedGroupsByObjectives: + """``filter_seed_groups_by_objectives`` is part of the duck-typed ScenarioStep surface.""" + + def test_remaining_objectives_is_keyword_only(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + with pytest.raises(TypeError): + atomic.filter_seed_groups_by_objectives(["obj1"]) # type: ignore[misc] + + def test_filter_drops_groups_not_in_remaining(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + atomic.filter_seed_groups_by_objectives(remaining_objectives=["obj2"]) + assert atomic.objectives == ["obj2"] + assert len(atomic.seed_groups) == 1 + + def test_filter_keeps_all_when_all_remain(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + atomic.filter_seed_groups_by_objectives(remaining_objectives=["obj1", "obj2"]) + assert atomic.objectives == ["obj1", "obj2"] + assert len(atomic.seed_groups) == 2 + + def test_filter_drops_all_when_none_remain(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + atomic.filter_seed_groups_by_objectives(remaining_objectives=[]) + assert atomic.objectives == [] + assert atomic.seed_groups == [] + + def test_filter_ignores_unknown_objectives(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + atomic.filter_seed_groups_by_objectives(remaining_objectives=["obj1", "does_not_exist"]) + assert atomic.objectives == ["obj1"] diff --git a/tests/unit/scenario/test_scenario_step.py b/tests/unit/scenario/test_scenario_step.py index 914882b64..3e7fca120 100644 --- a/tests/unit/scenario/test_scenario_step.py +++ b/tests/unit/scenario/test_scenario_step.py @@ -67,3 +67,72 @@ def test_step_result_defaults(): result = ScenarioStepResult(outcome="done") assert result.attack_results == [] assert result.step_identifier is None + + +def test_step_result_outcome_is_required(): + with pytest.raises(TypeError): + ScenarioStepResult() # type: ignore[call-arg] + + +def test_step_result_metadata_defaults_to_fresh_dict_per_instance(): + first = ScenarioStepResult(outcome="done") + second = ScenarioStepResult(outcome="done") + assert first.metadata == {} + assert first.metadata is not second.metadata + first.metadata["k"] = "v" + assert second.metadata == {} + + +def test_step_result_attack_results_defaults_to_fresh_list_per_instance(): + first = ScenarioStepResult(outcome="done") + second = ScenarioStepResult(outcome="done") + assert first.attack_results == [] + assert first.attack_results is not second.attack_results + first.attack_results.append("sentinel") # type: ignore[arg-type] + assert second.attack_results == [] + + +def test_step_result_accepts_all_fields(): + identifier = ComponentIdentifier.of(_ConcreteStep(), params={"name": "x", "outputs": ["done"]}) + metadata = {"step_name": "x", "extra": 1} + result = ScenarioStepResult( + outcome="success", + attack_results=[], + step_identifier=identifier, + metadata=metadata, + ) + assert result.outcome == "success" + assert result.step_identifier is identifier + assert result.metadata == metadata + + +class _StepWithoutProcessAsync(ScenarioStep): + """Subclass that forgets to implement ``process_async``.""" + + def __init__(self) -> None: + self.name = "incomplete" + self.outputs = ["done"] + + +def test_subclass_missing_process_async_cannot_instantiate(): + with pytest.raises(TypeError, match="process_async"): + _StepWithoutProcessAsync() # type: ignore[abstract] + + +class _StepWithDefaultIdentifier(ScenarioStep): + """Subclass that overrides only ``process_async`` — inherits identifier behavior.""" + + def __init__(self) -> None: + self.name = "inherits_identifier" + self.outputs = ["done", "skipped"] + + async def process_async(self) -> ScenarioStepResult: + return ScenarioStepResult(outcome="done") + + +def test_subclass_inherits_default_build_identifier(): + step = _StepWithDefaultIdentifier() + identifier = step.get_identifier() + assert identifier.params["name"] == "inherits_identifier" + assert identifier.params["outputs"] == ["done", "skipped"] + assert identifier.children == {} From 30f6deeb5f822ee556ba3870930610065bcb1fbc Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:54:36 -0700 Subject: [PATCH 10/42] TEST: augment StrategyGraph + state coverage (Phase 10c) Adds tests for the Phase 3 state-machine layer covering gaps in the existing suite: - Performance: counting-mock assertion that an N-step linear graph invokes exactly N policy actions (guards against N**2 retraversal). - State correctness: terminal_states is immune to external mutation of the input set; multi-terminal policies can reach an alternate terminal (FAILED, not just COMPLETE). - Determinism: history ordering is identical across reset + re-run. - Branching dispatch: parametrized 3-way branch confirms transitions are dict-lookup based rather than isinstance chains. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/scenario/test_strategy_graph.py | 98 ++++++++++++++++++- .../scenario/test_strategy_graph_branching.py | 51 ++++++++++ 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/tests/unit/scenario/test_strategy_graph.py b/tests/unit/scenario/test_strategy_graph.py index b74e213e8..b0ed4e141 100644 --- a/tests/unit/scenario/test_strategy_graph.py +++ b/tests/unit/scenario/test_strategy_graph.py @@ -4,6 +4,7 @@ """Tests for ``pyrit.scenario.core.strategy_graph``.""" from enum import Enum +from unittest.mock import AsyncMock import pytest @@ -28,7 +29,6 @@ def _make_graph(*, policy: StrategyPolicy) -> StrategyGraph: class TestStrategyPolicyInit: - def test_requires_non_empty_terminal_states(self): with pytest.raises(ValueError, match="at least one terminal state"): StrategyPolicy( @@ -105,6 +105,25 @@ async def _action(graph): assert policy.is_terminal(state=_State.END) assert not policy.is_terminal(state=_State.START) + def test_terminal_states_frozen_against_input_mutation(self): + """Mutating the original set passed in must not leak into the policy.""" + + async def _action(graph): + return _State.END, None + + mutable_terminals: set = {_State.END} + policy: StrategyPolicy = StrategyPolicy( + actions={_State.START: _action}, + initial_state=_State.START, + terminal_states=mutable_terminals, # type: ignore[arg-type] + ) + + mutable_terminals.add(_State.MIDDLE) + + assert isinstance(policy.terminal_states, frozenset) + assert policy.terminal_states == frozenset({_State.END}) + assert not policy.is_terminal(state=_State.MIDDLE) + # --------------------------------------------------------------------------- # StrategyGraph traversal @@ -296,3 +315,80 @@ async def executing_action(graph): results = [r async for r in graph.event_loop_async()] assert [r.outcome for r in results] == ["finished"] assert graph.current_state == ScenarioCoreState.COMPLETE + + +async def test_event_loop_invokes_each_action_exactly_once_in_linear_graph(): + """Performance guard: an N-state linear graph triggers exactly N actions, not N**2.""" + n = 10 + actions: dict[int, AsyncMock] = { + i: AsyncMock(return_value=(i + 1, ScenarioStepResult(outcome=f"step_{i}"))) for i in range(n) + } + + graph: StrategyGraph[object, int] = StrategyGraph( + policy=StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({n}), + ), + ) + + results = [r async for r in graph.event_loop_async()] + + assert len(results) == n + for i, action in actions.items(): + assert action.call_count == 1, f"action[{i}] invoked {action.call_count} times" + assert sum(action.call_count for action in actions.values()) == n + + +async def test_event_loop_reaches_alternate_terminal_state(): + """A policy with multiple terminals must allow stopping on any of them.""" + + async def failing_action(graph): + return ScenarioCoreState.FAILED, ScenarioStepResult(outcome="aborted") + + graph = _make_graph( + policy=StrategyPolicy( + actions={ScenarioCoreState.EXECUTING: failing_action}, + initial_state=ScenarioCoreState.EXECUTING, + terminal_states=frozenset({ScenarioCoreState.COMPLETE, ScenarioCoreState.FAILED}), + ), + ) + + results = [r async for r in graph.event_loop_async()] + + assert [r.outcome for r in results] == ["aborted"] + assert graph.current_state == ScenarioCoreState.FAILED + assert graph.is_terminal + + +async def test_history_order_is_deterministic_across_runs(): + """Two consecutive runs of the same graph yield identical history orderings.""" + n = 5 + + def _make_action(*, index: int): + async def _action(graph): + return index + 1, ScenarioStepResult(outcome=f"step_{index}") + + return _action + + actions = {i: _make_action(index=i) for i in range(n)} + + graph: StrategyGraph[object, int] = StrategyGraph( + policy=StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({n}), + ), + ) + + _ = [r async for r in graph.event_loop_async()] + first_run = [(state, result.outcome) for state, result in graph.history] + + graph.reset() + _ = [r async for r in graph.event_loop_async()] + second_run = [(state, result.outcome) for state, result in graph.history] + + expected = [(i, f"step_{i}") for i in range(n)] + assert first_run == expected + assert second_run == expected + assert first_run == second_run diff --git a/tests/unit/scenario/test_strategy_graph_branching.py b/tests/unit/scenario/test_strategy_graph_branching.py index 2d98cad29..86a99aaa3 100644 --- a/tests/unit/scenario/test_strategy_graph_branching.py +++ b/tests/unit/scenario/test_strategy_graph_branching.py @@ -24,6 +24,8 @@ from enum import Enum +import pytest + from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy @@ -153,3 +155,52 @@ async def test_reset_replays_branching_graph(): assert [r.outcome for r in results] == ["violation", "done"] assert sweep.call_count == 2 assert escalation.call_count == 2 + + +# --------------------------------------------------------------------------- +# 3-way dispatch — confirms transitions are dict-lookup based, not isinstance chains. +# --------------------------------------------------------------------------- + + +class _DispatchState(str, Enum): + DISPATCH = "dispatch" + BRANCH_A = "branch_a" + BRANCH_B = "branch_b" + BRANCH_C = "branch_c" + COMPLETE = "complete" + + +def _build_three_way_graph(*, target: _DispatchState) -> StrategyGraph: + async def dispatch_action(graph): + return target, ScenarioStepResult(outcome=f"to_{target.value}") + + async def branch_action(graph): + return _DispatchState.COMPLETE, ScenarioStepResult(outcome=f"reached_{graph.current_state.value}") + + policy: StrategyPolicy[ScenarioStep, _DispatchState] = StrategyPolicy( + actions={ + _DispatchState.DISPATCH: dispatch_action, + _DispatchState.BRANCH_A: branch_action, + _DispatchState.BRANCH_B: branch_action, + _DispatchState.BRANCH_C: branch_action, + }, + initial_state=_DispatchState.DISPATCH, + terminal_states=frozenset({_DispatchState.COMPLETE}), + ) + return StrategyGraph(policy=policy) + + +@pytest.mark.parametrize( + "target", + [_DispatchState.BRANCH_A, _DispatchState.BRANCH_B, _DispatchState.BRANCH_C], +) +async def test_three_way_branch_dispatches_to_target_state(target: _DispatchState) -> None: + graph = _build_three_way_graph(target=target) + + results = [r async for r in graph.event_loop_async()] + + assert [r.outcome for r in results] == [f"to_{target.value}", f"reached_{target.value}"] + states_before = [state for state, _ in graph.history] + assert states_before == [_DispatchState.DISPATCH, target] + assert graph.current_state == _DispatchState.COMPLETE + assert graph.is_terminal From 9e4657428e89d49b8ad257e70e2d6997d694bce8 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:54:58 -0700 Subject: [PATCH 11/42] TEST: augment adaptive scenario migration coverage (Phase 10f) Adds high-signal tests for the Phase 6b adaptive scenario migration: * AdaptiveStep is a ScenarioStep subclass (not AtomicAttack), with name aliasing atomic_attack_name for resume bookkeeping. * _build_identifier output is stable when techniques dict is constructed in reversed key order. * _build_adaptive_linear_policy + _build_execution_graph build a StrategyPolicy[ScenarioStep, int] with initial_state=0, terminal_states={len(steps)}, and one action per pre-terminal state. * Event loop visits each step exactly once, terminates, propagates 'success'/'exhausted' outcomes verbatim, and binds/unbinds current_step around each action. * End-to-end smoke: a real AdaptiveStep plugged into the adaptive linear policy emits 'success' as a real transition label. Test count: 83 -> 93 (10 new tests). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/adaptive/test_adaptive_step.py | 60 +++++++- .../scenarios/adaptive/test_text_adaptive.py | 141 +++++++++++++++++- 2 files changed, 198 insertions(+), 3 deletions(-) diff --git a/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py b/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py index a80d380bd..a2759ebd0 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py +++ b/tests/unit/scenario/scenarios/adaptive/test_adaptive_step.py @@ -11,7 +11,8 @@ import pytest from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup, SeedObjective -from pyrit.scenario.core.scenario_step import ScenarioStepResult +from pyrit.scenario.core.atomic_attack import AtomicAttack +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep from pyrit.scenario.scenarios.adaptive.dispatcher import ( ADAPTIVE_ATTEMPT_LABEL, @@ -397,6 +398,63 @@ def test_identifier_nests_technique_identifiers(self, target, selector, seed_gro assert isinstance(nested, list) assert [child.class_name for child in nested] == ["A", "B"] + def test_identifier_is_stable_under_reversed_technique_order(self, target, selector, seed_group): + # The _build_identifier sort key (`sorted(self._techniques.items())`) + # must collapse reversed input dicts to the same identifier, otherwise + # step_identifier hashes would drift between runs that pick the same + # techniques in different insertion orders. + from pyrit.identifiers import ComponentIdentifier + + bundle_a = _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS]) + bundle_b = _make_bundle(name="b", outcomes=[AttackOutcome.SUCCESS]) + bundle_a.attack.get_identifier.return_value = ComponentIdentifier(class_name="A", class_module="test") + bundle_b.attack.get_identifier.return_value = ComponentIdentifier(class_name="B", class_module="test") + + def _make(techniques): + return AdaptiveStep( + atomic_attack_name="step-id", + objective_target=target, + techniques=techniques, + selector=selector, + seed_group=seed_group, + max_attempts_per_objective=5, + adaptive_context="violence", + ) + + forward = _make({"a": bundle_a, "b": bundle_b}).get_identifier() + reversed_ = _make({"b": bundle_b, "a": bundle_a}).get_identifier() + assert forward == reversed_ + + +@pytest.mark.usefixtures("patch_central_database") +class TestScenarioStepIdentity: + """Phase 6b invariant: AdaptiveStep is a ScenarioStep, NOT an AtomicAttack.""" + + def test_is_scenario_step_not_atomic_attack(self, target, selector, seed_group): + step = AdaptiveStep( + atomic_attack_name="step", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert isinstance(step, ScenarioStep) + assert not isinstance(step, AtomicAttack) + + def test_name_aliases_atomic_attack_name_for_resume_bookkeeping(self, target, selector, seed_group): + # The orchestrator's resume filter reads ``step.name`` uniformly across + # step types; AdaptiveStep must alias it to atomic_attack_name without + # subclassing AtomicAttack. + step = AdaptiveStep( + atomic_attack_name="step-resume-key", + objective_target=target, + techniques={"a": _make_bundle(name="a", outcomes=[AttackOutcome.SUCCESS])}, + selector=selector, + seed_group=seed_group, + ) + assert step.name == "step-resume-key" + assert step.name == step.atomic_attack_name + @pytest.mark.usefixtures("patch_central_database") class TestPersistAdaptiveContextLabel: diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index dd1276ee6..bd5788d73 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -5,16 +5,18 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pyrit.identifiers import ComponentIdentifier -from pyrit.models import SeedAttackGroup, SeedObjective +from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup, SeedObjective from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.core.scenario import BaselinePolicy +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep from pyrit.scenario.scenarios.adaptive.dispatcher import ( ADAPTIVE_CONTEXT_LABEL, @@ -509,3 +511,138 @@ async def test_initialize_async_rejects_explicit_baseline(self, mock_objective_t objective_target=mock_objective_target, include_baseline=True, ) + + +def _make_stub_step(*, name: str, outcome: str = "success") -> MagicMock: + """Build a ScenarioStep-spec stub whose process_async returns a fixed result.""" + step = MagicMock(spec=ScenarioStep) + step.name = name + step.outputs = ["success", "exhausted"] + step.process_async = AsyncMock(return_value=ScenarioStepResult(outcome=outcome, attack_results=[], metadata={})) + return step + + +@pytest.mark.usefixtures(*FIXTURES) +class TestAdaptiveLinearPolicy: + """The adaptive policy must dispatch via process_async with int states 0..N.""" + + def test_empty_steps_raises(self, mock_objective_scorer): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + with pytest.raises(ValueError, match="at least one step"): + scenario._build_adaptive_linear_policy(steps=[]) + + def test_initial_state_zero_and_terminal_state_is_step_count(self, mock_objective_scorer): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + steps = [_make_stub_step(name=f"s{i}") for i in range(3)] + policy = scenario._build_adaptive_linear_policy(steps=steps) + + assert isinstance(policy, StrategyPolicy) + assert policy.initial_state == 0 + assert policy.terminal_states == frozenset({3}) + # One action per non-terminal state; terminal state must not have an action. + assert set(policy.actions.keys()) == {0, 1, 2} + + def test_execution_graph_wraps_policy(self, mock_objective_scorer): + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + steps = [_make_stub_step(name="s0"), _make_stub_step(name="s1")] + graph = scenario._build_execution_graph(steps=steps) + + assert isinstance(graph, StrategyGraph) + assert graph.policy.initial_state == 0 + assert graph.policy.terminal_states == frozenset({2}) + + async def test_event_loop_visits_each_step_exactly_once_and_terminates(self, mock_objective_scorer): + # Guards against infinite loops and re-entry: each step's + # process_async must fire once and the graph must reach the terminal + # state without revisiting any state. + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + steps = [_make_stub_step(name=f"s{i}", outcome="exhausted") for i in range(4)] + graph = scenario._build_execution_graph(steps=steps) + + states_visited: list[int] = [] + results: list[ScenarioStepResult] = [] + async for result in graph.event_loop_async(): + states_visited.append(graph.current_state) + results.append(result) + + for step in steps: + assert step.process_async.call_count == 1 + # history records (state_before, result) pairs for the four pre-terminal states. + assert [state for state, _ in graph.history] == [0, 1, 2, 3] + assert len(results) == 4 + assert graph.is_terminal + assert graph.current_state == 4 + + async def test_action_preserves_step_outcome_label(self, mock_objective_scorer): + # The adaptive policy intentionally does NOT collapse outcomes to + # "completed" the way the default policy's AtomicAttack branch does; + # "success" and "exhausted" must propagate verbatim. + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + success_step = _make_stub_step(name="ok", outcome="success") + exhausted_step = _make_stub_step(name="fail", outcome="exhausted") + graph = scenario._build_execution_graph(steps=[success_step, exhausted_step]) + + outcomes = [r.outcome async for r in graph.event_loop_async()] + assert outcomes == ["success", "exhausted"] + + async def test_action_binds_current_step_on_graph(self, mock_objective_scorer): + # The adaptive _action wraps process_async with bind_current_step so + # external observers (e.g. the Scenario orchestrator) can read which + # step is running. Capture graph.current_step from inside the stub. + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + + observed: list[ScenarioStep | None] = [] + + async def _capture(): + observed.append(graph.current_step) + return ScenarioStepResult(outcome="success") + + spy_step = MagicMock(spec=ScenarioStep) + spy_step.name = "spy" + spy_step.outputs = ["success", "exhausted"] + spy_step.process_async = AsyncMock(side_effect=_capture) + + graph = scenario._build_execution_graph(steps=[spy_step]) + async for _ in graph.event_loop_async(): + pass + + assert observed == [spy_step] + # The finally block clears the binding after the action returns. + assert graph.current_step is None + + async def test_adaptive_step_returning_success_runs_through_policy( + self, mock_objective_target, mock_objective_scorer + ): + # End-to-end-ish integration: a real AdaptiveStep instance plugged + # into the adaptive linear policy emits "success" as a real + # transition label (regression guard against the default policy's + # AtomicAttack-only dispatch path swallowing the outcome). + import random + + from pyrit.scenario.scenarios.adaptive.dispatcher import TechniqueBundle + from pyrit.scenario.scenarios.adaptive.selector import AdaptiveTechniqueSelector + + bundle_attack = MagicMock(name="bundle-attack") + bundle = TechniqueBundle(attack=bundle_attack) + seed_group = _make_seed_group(value="obj-x") + selector = AdaptiveTechniqueSelector(epsilon=0.0, pool_threshold=1, rng=random.Random(0)) + step = AdaptiveStep( + atomic_attack_name="adaptive_x", + objective_target=mock_objective_target, + techniques={"a": bundle}, + selector=selector, + seed_group=seed_group, + ) + + async def _stub_inner(*, bundle, attempt_labels): + return AttackResult(conversation_id="c", objective="obj-x", outcome=AttackOutcome.SUCCESS) + + step._run_inner_attack_async = AsyncMock(side_effect=_stub_inner) # type: ignore[method-assign] + + scenario = TextAdaptive(objective_scorer=mock_objective_scorer) + graph = scenario._build_execution_graph(steps=[step]) + results = [r async for r in graph.event_loop_async()] + + assert len(results) == 1 + assert results[0].outcome == "success" + assert graph.is_terminal From c2acb5e98fb0ab9efb6bbf432866684dfaec86aa Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:55:09 -0700 Subject: [PATCH 12/42] TEST: augment step_identifier persistence coverage (Phase 10d) Adds tests that fill the gaps Phase 4 left around the additive step_identifier column: - step_identifier: no false dedup across attack-execution child configs, list (not nested dict) shape, execution-order is preserved, child param changes propagate to hash, eval_version is in params. - memory interface: legacy AttackResult rows (NULL step_identifier) round-trip cleanly, and multiple results sharing one step_identifier are retrievable via a single STEP filter. - alembic: a1c2e4f80b3d revision metadata, upgrade adds the column, full upgrade->downgrade round-trip restores the pre-Phase-4 schema. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../unit/identifiers/test_step_identifier.py | 82 +++++++++++++++++++ .../test_interface_attack_results.py | 50 +++++++++++ tests/unit/memory/test_migration.py | 82 +++++++++++++++++++ 3 files changed, 214 insertions(+) diff --git a/tests/unit/identifiers/test_step_identifier.py b/tests/unit/identifiers/test_step_identifier.py index 13e8229cb..2340e2eb4 100644 --- a/tests/unit/identifiers/test_step_identifier.py +++ b/tests/unit/identifiers/test_step_identifier.py @@ -145,3 +145,85 @@ def test_hash_differs_when_step_name_differs(): def test_step_eval_version_is_positive_int(): assert isinstance(STEP_EVAL_VERSION, int) assert STEP_EVAL_VERSION >= 1 + + +def test_hash_differs_when_attack_execution_child_config_differs(): + """Same step_name and outcome but different atomic child params must NOT dedup.""" + atomic_a = _make_atomic_identifier(hash_suffix="a") + atomic_b = _make_atomic_identifier(hash_suffix="b") + + only_a = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic_a], + ) + only_b = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic_b], + ) + assert only_a.hash != only_b.hash + + +def test_attack_executions_value_is_list_not_nested_dict(): + """children["attack_executions"] must be a ``list[ComponentIdentifier]``, never a nested dict.""" + atomic = _make_atomic_identifier() + result = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic], + ) + nested = result.children["attack_executions"] + assert isinstance(nested, list) + assert not isinstance(nested, dict) + assert all(isinstance(c, ComponentIdentifier) for c in nested) + + +def test_attack_executions_preserves_input_order(): + """``attack_execution_identifiers`` is preserved in execution order (per docstring), so reversed inputs + produce a different hash. This guards against an accidental sort that would erase execution-order + semantics for branching/adaptive scenarios.""" + atomic_a = _make_atomic_identifier(hash_suffix="a") + atomic_b = _make_atomic_identifier(hash_suffix="b") + + ab = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic_a, atomic_b], + ) + ba = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[atomic_b, atomic_a], + ) + # Lists preserve caller-supplied order. + assert [c.params["marker"] for c in ab.children["attack_executions"]] == ["a", "b"] + assert [c.params["marker"] for c in ba.children["attack_executions"]] == ["b", "a"] + # Hash reflects order — these are NOT considered the same step execution. + assert ab.hash != ba.hash + + +def test_hash_changes_when_a_child_param_changes(): + """If an atomic child's params change, the step identifier's hash changes too.""" + baseline = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[_make_atomic_identifier(hash_suffix="v1")], + ) + mutated = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[_make_atomic_identifier(hash_suffix="v2")], + ) + assert baseline.hash != mutated.hash + + +def test_step_identifier_eval_version_in_params(): + """The schema/eval version is embedded in params so old rows preserve their original version.""" + result = build_step_identifier( + step_name="opening", + outcome="violation", + attack_execution_identifiers=[], + ) + assert "eval_version" in result.params + assert result.params["eval_version"] == STEP_EVAL_VERSION diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 6cc6ddf97..bfed3e93b 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -1756,6 +1756,56 @@ def test_get_attack_results_by_step_identifier_filter_skips_legacy_rows(sqlite_i assert [r.conversation_id for r in results] == ["conv_new"] +def test_legacy_attack_result_without_step_identifier_round_trips(sqlite_instance: MemoryInterface): + """A legacy AttackResult (no step_identifier) persists and loads back with step_identifier == None.""" + legacy_ar = _make_attack_result_with_identifier("conv_legacy", "CrescendoAttack") + assert legacy_ar.step_identifier is None + sqlite_instance.add_attack_results_to_memory(attack_results=[legacy_ar]) + + results = sqlite_instance.get_attack_results() + assert len(results) == 1 + assert results[0].conversation_id == "conv_legacy" + assert results[0].step_identifier is None + + +def test_multiple_attack_results_share_step_identifier(sqlite_instance: MemoryInterface): + """Two AttackResults for the same logical step share one step_identifier hash after persistence, + and a single STEP filter retrieves both.""" + from pyrit.identifiers.step_identifier import build_step_identifier + + ar1 = _make_attack_result_with_identifier("conv_share_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_share_2", "CrescendoAttack") + other = _make_attack_result_with_identifier("conv_other", "CrescendoAttack") + + shared_step = build_step_identifier( + step_name="opening_phase", + outcome="done", + attack_execution_identifiers=[ar1.atomic_attack_identifier], + ) + ar1.step_identifier = shared_step + ar2.step_identifier = shared_step + other.step_identifier = build_step_identifier( + step_name="escalation_phase", + outcome="done", + attack_execution_identifiers=[other.atomic_attack_identifier], + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, other]) + + results = sqlite_instance.get_attack_results( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.STEP, + property_path="$.hash", + value=shared_step.hash, + partial_match=False, + ) + ], + ) + assert {r.conversation_id for r in results} == {"conv_share_1", "conv_share_2"} + hashes = {r.step_identifier.hash for r in results} + assert hashes == {shared_step.hash} + + def test_get_attack_results_targeted_harm_categories_emits_deprecation_warning(sqlite_instance: MemoryInterface): """Test that passing targeted_harm_categories emits a DeprecationWarning.""" import warnings diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py index 0140ce5b1..5c6aa42a0 100644 --- a/tests/unit/memory/test_migration.py +++ b/tests/unit/memory/test_migration.py @@ -262,3 +262,85 @@ def test_generate_schema_migration_with_diffs_creates_revision(): mock_revision.assert_called_once() finally: engine.dispose() + + +# --------------------------------------------------------------------------- +# Phase 4: step_identifier additive migration (a1c2e4f80b3d) +# --------------------------------------------------------------------------- + + +def _column_names(connection, table: str) -> set[str]: + return {col["name"] for col in inspect(connection).get_columns(table)} + + +def test_step_identifier_migration_metadata(): + """Phase 4 migration script declares the correct revision chain.""" + from pyrit.memory.alembic.versions import a1c2e4f80b3d_add_step_identifier as mod + + assert mod.revision == "a1c2e4f80b3d" + assert mod.down_revision == "7a1b2c3d4e5f" + assert mod.branch_labels is None + assert mod.depends_on is None + + +def test_step_identifier_migration_upgrade_adds_column(): + """Upgrading to a1c2e4f80b3d adds the nullable step_identifier column to AttackResultEntries.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "step-upgrade.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + + # Stop just before the step_identifier migration: no column yet. + command.upgrade(config, "7a1b2c3d4e5f") + cols_before = _column_names(connection, "AttackResultEntries") + assert "step_identifier" not in cols_before + assert "atomic_attack_identifier" in cols_before # sanity: prior migration applied + + command.upgrade(config, "a1c2e4f80b3d") + cols_after = _column_names(connection, "AttackResultEntries") + assert "step_identifier" in cols_after + finally: + engine.dispose() + + +def test_step_identifier_migration_downgrade_round_trip_leaves_clean_schema(): + """Round-tripping a1c2e4f80b3d (upgrade then downgrade) restores the pre-Phase-4 schema and leaves + the rest of the schema untouched.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "step-roundtrip.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + + command.upgrade(config, "7a1b2c3d4e5f") + baseline_cols = _column_names(connection, "AttackResultEntries") + baseline_tables = set(inspect(connection).get_table_names()) + + command.upgrade(config, "a1c2e4f80b3d") + assert "step_identifier" in _column_names(connection, "AttackResultEntries") + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version == "a1c2e4f80b3d" + + command.downgrade(config, "7a1b2c3d4e5f") + after_cols = _column_names(connection, "AttackResultEntries") + after_tables = set(inspect(connection).get_table_names()) + assert "step_identifier" not in after_cols + assert after_cols == baseline_cols + assert after_tables == baseline_tables + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version == "7a1b2c3d4e5f" + finally: + engine.dispose() From 15c0275c091b85a87d76f2794160bf529fc7a2e6 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 13:59:24 -0700 Subject: [PATCH 13/42] FEAT: add BroadSweepThenDeepDive branching scenario (Phase 7) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/scenarios/airt/__init__.py | 12 + .../scenarios/airt/sweep_then_deep_dive.py | 665 ++++++++++++++++++ .../unit/scenario/scenarios/airt/__init__.py | 0 .../airt/test_sweep_then_deep_dive.py | 622 ++++++++++++++++ 4 files changed, 1299 insertions(+) create mode 100644 pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py create mode 100644 tests/unit/scenario/scenarios/airt/__init__.py create mode 100644 tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py diff --git a/pyrit/scenario/scenarios/airt/__init__.py b/pyrit/scenario/scenarios/airt/__init__.py index f4eae9657..c98a41b23 100644 --- a/pyrit/scenario/scenarios/airt/__init__.py +++ b/pyrit/scenario/scenarios/airt/__init__.py @@ -12,6 +12,13 @@ from pyrit.scenario.scenarios.airt.psychosocial import Psychosocial, PsychosocialStrategy from pyrit.scenario.scenarios.airt.rapid_response import RapidResponse from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy +from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + BroadSweepThenDeepDiveStrategy, + CategoryAggregatingSweepStep, + FilteredDeepDiveStep, + SweepThenDeepDiveState, +) def __getattr__(name: str) -> Any: @@ -36,10 +43,14 @@ def __getattr__(name: str) -> Any: __all__ = [ + "BroadSweepThenDeepDive", + "BroadSweepThenDeepDiveStrategy", + "CategoryAggregatingSweepStep", "ContentHarms", "ContentHarmsStrategy", "Cyber", "CyberStrategy", + "FilteredDeepDiveStep", "Jailbreak", "JailbreakStrategy", "Leakage", @@ -50,4 +61,5 @@ def __getattr__(name: str) -> Any: "RapidResponseStrategy", "Scam", "ScamStrategy", + "SweepThenDeepDiveState", ] diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py new file mode 100644 index 000000000..b2d52ca01 --- /dev/null +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -0,0 +1,665 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``BroadSweepThenDeepDive`` — Phase 7 validation artifact for branching scenarios. + +First real consumer of the non-linear ``StrategyGraph`` policy. The flow is: + +1. **Sweep phase**: run one fast single-turn ``AtomicAttack`` across every + provided seed group. For each ``AttackResult`` produced, the configured + ``OutcomeScorer`` classifies the model's final response. Categories + (``display_group``) that emit the configured ``weakness_label`` are + tracked. +2. **Branch**: + - If at least one category was flagged, transition to **Deep dive**. + - Otherwise terminate immediately. +3. **Deep dive phase**: run each provided multi-turn ``AtomicAttack`` ONLY + against the categories the sweep flagged. Untargeted categories are + skipped; their names are stamped into ``ScenarioStepResult.metadata`` + for diagnostics. + +This file intentionally bundles its two custom ``ScenarioStep`` subclasses +(``CategoryAggregatingSweepStep``, ``FilteredDeepDiveStep``) inline rather +than promoting them to ``pyrit/scenario/core/``. Phase 9 will extract any +of them once a second scenario shows the same shape. + +Scope: this scenario takes its sweep + deep-dive ``AtomicAttack`` lists +explicitly via the constructor (no registry-driven technique selection). +The goal of Phase 7 is to prove the branching policy works end-to-end +without dragging in registry/factory machinery; subclasses with real +technique selection can override ``_build_atomic_attacks_for_phases``. +""" + +from __future__ import annotations + +import logging +from enum import Enum +from typing import TYPE_CHECKING, ClassVar, Optional, cast + +from pyrit.common import apply_defaults +from pyrit.models import Message +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario import Scenario +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.scenario_strategy import ScenarioStrategy +from pyrit.scenario.core.strategy_graph import ( + PolicyAction, + StrategyGraph, + StrategyPolicy, +) +from pyrit.score.decorators.outcome_scorer import OutcomeScorer + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from pyrit.identifiers import ComponentIdentifier + from pyrit.models import AttackResult + from pyrit.scenario.core.atomic_attack import AtomicAttack + from pyrit.score import TrueFalseScorer + +logger = logging.getLogger(__name__) + + +class SweepThenDeepDiveState(str, Enum): + """ + The three states that drive ``BroadSweepThenDeepDive``. + + Inherits ``str`` for the canonical ``(str, Enum)`` design language used + by ``ScenarioCoreState``, ``BaselinePolicy``, and ``AttackOutcome``: the + enum members serialize naturally in identifiers and logs. + + The two terminal states are distinct so that downstream consumers can + tell why the scenario stopped (``ALL_SAFE`` = nothing to escalate; + ``COMPLETE`` = deep dive finished). + """ + + SWEEPING = "sweeping" + DEEP_DIVING = "deep_diving" + COMPLETE = "complete" + ALL_SAFE = "all_safe" + + +class BroadSweepThenDeepDiveStrategy(ScenarioStrategy): + """ + Single-member strategy enum for ``BroadSweepThenDeepDive``. + + This scenario doesn't run a technique-selection menu — the sweep and + deep-dive attacks are constructor inputs. The strategy enum exists only + to satisfy the base ``Scenario`` contract; a future migration to + registry-driven techniques would replace this with a dynamic + ``build_strategy_class_from_specs`` call. + """ + + DEFAULT = ("default", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + """Return the strategy aggregate tags this enum exposes.""" + return {"all"} + + +class CategoryAggregatingSweepStep(ScenarioStep): + """ + Run one ``AtomicAttack`` across all seed groups and classify each result. + + The wrapped attack is invoked exactly once via ``AtomicAttack.run_async``. + Each ``AttackResult`` produced is then handed to the configured + ``OutcomeScorer.resolve_outcome_async``; the result's ``display_group`` + serves as the category key. If any category's label matches + ``weakness_label``, the step emits ``"found_weaknesses"``; otherwise it + emits ``"all_safe"``. The set of weak category names is stamped into + ``metadata['weak_categories']`` for the downstream deep-dive step. + + Duck-types AtomicAttack-like attributes (``atomic_attack_name``, + ``display_group``, ``objectives``, ``seed_groups``, + ``filter_seed_groups_by_objectives``) so the orchestrator's existing + resume bookkeeping does not need to special-case branching scenarios. + """ + + _OUTPUTS: ClassVar[tuple[str, ...]] = ("found_weaknesses", "all_safe") + + def __init__( + self, + *, + atomic_attack_name: str, + atomic_attack: AtomicAttack, + outcome_scorer: OutcomeScorer, + weakness_label: str, + max_concurrency: int = 1, + ) -> None: + """ + Initialize the sweep step. + + Args: + atomic_attack_name (str): Step name used for the + ``ScenarioStepResult.metadata['step_name']`` stamp and for + the orchestrator's resume-by-name path. + atomic_attack (AtomicAttack): The single-turn attack to dispatch. + Its ``seed_groups``, ``display_group``, and ``objectives`` + are surfaced upward so the orchestrator sees a familiar + shape. + outcome_scorer (OutcomeScorer): Decorator that turns the wrapped + scorer's output into a transition label. Must declare + ``weakness_label`` as one of its outputs. + weakness_label (str): The label emitted by ``outcome_scorer`` + that indicates a category was breached and should be + escalated to deep dive. + max_concurrency (int): Forwarded to ``AtomicAttack.run_async``. + + Raises: + ValueError: If ``weakness_label`` is not declared in + ``outcome_scorer.outcomes``. + """ + if weakness_label not in outcome_scorer.outcomes: + raise ValueError( + f"weakness_label {weakness_label!r} is not declared as an outcome of the " + f"supplied OutcomeScorer (declared: {outcome_scorer.outcomes!r})." + ) + + self.name = atomic_attack_name + self.outputs = list(self._OUTPUTS) + self._atomic = atomic_attack + self._outcome_scorer = outcome_scorer + self._weakness_label = weakness_label + self._max_concurrency = max_concurrency + + # Duck-typed AtomicAttack-like attributes so the orchestrator's + # resume bookkeeping continues to work without changes. + self.atomic_attack_name = atomic_attack_name + self.display_group = atomic_attack.display_group + self.seed_groups = list(atomic_attack.seed_groups) + self.objectives = list(atomic_attack.objectives) + + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: + """ + Defer to the wrapped attack's seed-group filter. + + Mirrors the ``AtomicAttack`` shape so ``Scenario._execute_scenario_async`` + can prune already-completed objectives without special-casing this + step type. + + Args: + remaining_objectives: Remaining objective strings the orchestrator + wants this step to cover. + """ + self._atomic.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) + self.seed_groups = list(self._atomic.seed_groups) + self.objectives = list(self._atomic.objectives) + + async def process_async(self) -> ScenarioStepResult: + """ + Execute the sweep and classify each result. + + Returns: + ScenarioStepResult: Outcome is ``"found_weaknesses"`` if any + category emitted ``self._weakness_label``, else + ``"all_safe"``. ``metadata['weak_categories']`` carries the + set of flagged categories and ``metadata['category_outcomes']`` + carries the full per-category label mapping for diagnostics. + """ + executor_result = await self._atomic.run_async( + max_concurrency=self._max_concurrency, + return_partial_on_failure=True, + ) + attack_results: list[AttackResult] = list(executor_result.completed_results) + + weak_categories: set[str] = set() + category_outcomes: dict[str, str] = {} + + for result in attack_results: + category = self._category_key(result) + label = await self._classify_async(result=result) + category_outcomes[category] = label + if label == self._weakness_label: + weak_categories.add(category) + + outcome = "found_weaknesses" if weak_categories else "all_safe" + return ScenarioStepResult( + outcome=outcome, + attack_results=attack_results, + metadata={ + "weak_categories": weak_categories, + "category_outcomes": category_outcomes, + }, + ) + + async def _classify_async(self, *, result: AttackResult) -> str: + """ + Map a single ``AttackResult`` to one of the scorer's labels. + + Returns the ``UNSCORED`` sentinel when the result has no + ``last_response`` to feed the scorer — the surrounding policy can + choose to treat that as a weakness or not. + + Args: + result: The attack result whose final response should be classified. + + Returns: + str: The label emitted by ``self._outcome_scorer``. + """ + if result.last_response is None: + return OutcomeScorer.UNSCORED + message = Message(message_pieces=[result.last_response], skip_validation=True) + return await self._outcome_scorer.resolve_outcome_async(message, objective=result.objective) + + def _category_key(self, result: AttackResult) -> str: + """Return the display-group key used to bucket this result by category.""" + return self.display_group or "" + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this sweep step. + + Nests the wrapped attack's identifier under ``children["sweep_attack"]`` + and the outcome scorer under ``children["outcome_scorer"]`` so + hash drift in either propagates upward. + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + from pyrit.identifiers import ComponentIdentifier + + return ComponentIdentifier.of( + self, + params={ + "atomic_attack_name": self.atomic_attack_name, + "outputs": list(self.outputs), + "weakness_label": self._weakness_label, + }, + children={ + "sweep_attack": self._atomic.get_identifier(), + "outcome_scorer": self._outcome_scorer.wrapped_scorer.get_identifier(), + }, + ) + + +class FilteredDeepDiveStep(ScenarioStep): + """ + Run each supplied ``AtomicAttack`` only if its category was flagged weak. + + Receives a ``weak_categories_ref`` callable that returns the live set + of weak categories (typically a closure over the sweep step's result). + For each wrapped atomic attack, ``display_group`` is checked against + the live set; non-matching atomics are skipped. Always emits the + single ``"done"`` outcome — deep-dive results are aggregated regardless + of their individual ``AttackOutcome``. + + Duck-types AtomicAttack-like attributes (``atomic_attack_name``, + ``display_group``, ``objectives``, ``seed_groups``) for orchestrator + bookkeeping. + """ + + _OUTPUTS: ClassVar[tuple[str, ...]] = ("done",) + + def __init__( + self, + *, + atomic_attack_name: str, + atomic_attacks: Sequence[AtomicAttack], + weak_categories_ref: Callable[[], set[str]], + max_concurrency: int = 1, + ) -> None: + """ + Initialize the filtered deep-dive step. + + Args: + atomic_attack_name (str): Step name for resume / diagnostics. + atomic_attacks (Sequence[AtomicAttack]): The candidate atomics + to run conditionally. + weak_categories_ref (Callable[[], set[str]]): Callable that + returns the live set of weak categories at dispatch time. + Wrapping the lookup in a callable (rather than passing a + set by reference) lets the sweep step build its weak-set + fresh on each scenario attempt without the deep-dive step + holding a stale reference. + max_concurrency (int): Forwarded to each ``AtomicAttack.run_async``. + + Raises: + ValueError: If ``atomic_attacks`` is empty. + """ + if not atomic_attacks: + raise ValueError("FilteredDeepDiveStep requires at least one atomic attack.") + + self.name = atomic_attack_name + self.outputs = list(self._OUTPUTS) + self._atomics = list(atomic_attacks) + self._weak_categories_ref = weak_categories_ref + self._max_concurrency = max_concurrency + + self.atomic_attack_name = atomic_attack_name + self.display_group = atomic_attack_name + # Aggregate seed groups across all wrapped atomics so the orchestrator + # sees a comprehensive view even when individual atomics will be skipped. + self.seed_groups = [g for atomic in self._atomics for g in atomic.seed_groups] + seen_objectives: set[str] = set() + self.objectives = [] + for atomic in self._atomics: + for objective in atomic.objectives: + if objective not in seen_objectives: + self.objectives.append(objective) + seen_objectives.add(objective) + + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: + """ + Forward objective filtering to each wrapped atomic attack. + + Args: + remaining_objectives: Remaining objectives the orchestrator + wants this step to cover. + """ + for atomic in self._atomics: + atomic.filter_seed_groups_by_objectives(remaining_objectives=remaining_objectives) + self.seed_groups = [g for atomic in self._atomics for g in atomic.seed_groups] + self.objectives = list(remaining_objectives) + + async def process_async(self) -> ScenarioStepResult: + """ + Run each wrapped atomic conditionally on its category being flagged. + + Returns: + ScenarioStepResult: Outcome is always ``"done"``. The aggregated + ``attack_results`` contain results from every atomic that + was actually dispatched. ``metadata['skipped_categories']`` + lists categories that were not in the weak set; + ``metadata['dispatched_categories']`` lists the categories + actually exercised. + """ + weak = self._weak_categories_ref() + attack_results: list[AttackResult] = [] + dispatched: list[str] = [] + skipped: list[str] = [] + + for atomic in self._atomics: + category = atomic.display_group or atomic.atomic_attack_name + if category not in weak: + skipped.append(category) + continue + executor_result = await atomic.run_async( + max_concurrency=self._max_concurrency, + return_partial_on_failure=True, + ) + attack_results.extend(executor_result.completed_results) + dispatched.append(category) + + return ScenarioStepResult( + outcome="done", + attack_results=attack_results, + metadata={ + "dispatched_categories": dispatched, + "skipped_categories": skipped, + }, + ) + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for the deep-dive step. + + Nests each wrapped atomic's identifier under ``children["deep_dive_attacks"]`` + (sorted by ``atomic_attack_name`` for deterministic hash stability). + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + from pyrit.identifiers import ComponentIdentifier + + nested_ids = [atomic.get_identifier() for atomic in sorted(self._atomics, key=lambda a: a.atomic_attack_name)] + return ComponentIdentifier.of( + self, + params={ + "atomic_attack_name": self.atomic_attack_name, + "outputs": list(self.outputs), + }, + children={"deep_dive_attacks": nested_ids}, + ) + + +class BroadSweepThenDeepDive(Scenario): + """ + Branching scenario: sweep with fast attacks, escalate only on weakness. + + Composes two phases into a non-linear ``StrategyGraph``: + + - **Sweep**: a ``CategoryAggregatingSweepStep`` runs one provided + single-turn ``AtomicAttack`` across every seed group and reports + which categories the model breached. + - **Deep dive**: a ``FilteredDeepDiveStep`` runs the provided multi-turn + ``AtomicAttack`` list — but only for categories the sweep flagged. + + The policy short-circuits to ``ALL_SAFE`` when the sweep finds no + weaknesses, so the deep-dive step is never invoked. This is the Phase 7 + validation artifact that the branching graph abstraction actually + expresses scenarios that today cannot be cleanly written with the + flat-loop pattern. + """ + + VERSION: int = 1 + + @apply_defaults + def __init__( + self, + *, + sweep_atomic_attack: AtomicAttack, + deep_dive_atomic_attacks: Sequence[AtomicAttack], + outcome_scorer: OutcomeScorer, + weakness_label: str = "safety_violation", + objective_scorer: Optional[TrueFalseScorer] = None, + scenario_result_id: Optional[str] = None, + ) -> None: + """ + Initialize the branching scenario. + + Args: + sweep_atomic_attack (AtomicAttack): A single fast attack run + across all seed groups during the sweep phase. + deep_dive_atomic_attacks (Sequence[AtomicAttack]): The atomics + to consider during the deep-dive phase. Each is gated by + its ``display_group`` matching a category flagged by the + sweep step. + outcome_scorer (OutcomeScorer): Wraps the per-response classifier + whose output labels each sweep result. Must declare + ``weakness_label`` as one of its outputs. + weakness_label (str): The label emitted by ``outcome_scorer`` + that signals a category breach. Defaults to + ``"safety_violation"``. + objective_scorer (TrueFalseScorer | None): Forwarded to the + base ``Scenario``. Defaults to ``outcome_scorer.wrapped_scorer`` + cast to ``TrueFalseScorer`` so dataset config bootstrap + stays consistent. + scenario_result_id (str | None): Optional ID of an existing + scenario result to resume. + + Raises: + ValueError: If ``deep_dive_atomic_attacks`` is empty. + """ + if not deep_dive_atomic_attacks: + raise ValueError("BroadSweepThenDeepDive requires at least one deep_dive_atomic_attack.") + + self._sweep_atomic = sweep_atomic_attack + self._deep_dive_atomics: list[AtomicAttack] = list(deep_dive_atomic_attacks) + self._outcome_scorer = outcome_scorer + self._weakness_label = weakness_label + + # Shared mutable handle the sweep step updates and the deep-dive step + # reads. Reset on each ``run_async`` via ``_build_execution_graph``. + self._weak_categories: set[str] = set() + + effective_objective_scorer = ( + objective_scorer if objective_scorer is not None else cast("TrueFalseScorer", outcome_scorer.wrapped_scorer) + ) + + super().__init__( + version=self.VERSION, + objective_scorer=effective_objective_scorer, + strategy_class=self.get_strategy_class(), + scenario_result_id=scenario_result_id, + ) + + @classmethod + def get_strategy_class(cls) -> type[ScenarioStrategy]: + """Return the (single-member) strategy enum class.""" + return BroadSweepThenDeepDiveStrategy + + @classmethod + def get_default_strategy(cls) -> ScenarioStrategy: + """Return the only strategy member.""" + return BroadSweepThenDeepDiveStrategy.DEFAULT + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + """ + Return an empty dataset configuration. + + The atomics are supplied directly via the constructor, so the + base scenario's auto-build from registered datasets is unused. + + Returns: + DatasetConfiguration: An empty configuration; this scenario + does not consume registered datasets. + """ + return DatasetConfiguration() + + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + """ + Return the atomics in canonical phase order: sweep first, deep dives after. + + The orchestrator uses this list as the resume-by-name keyspace. We + keep the order stable so resume can rehydrate deterministically. + + Returns: + list[AtomicAttack]: ``[sweep_atomic, *deep_dive_atomics]``. + """ + return [self._sweep_atomic, *self._deep_dive_atomics] + + def _build_execution_graph( # ty: ignore[invalid-method-override] + self, + *, + steps: Optional[Sequence[ScenarioStep]] = None, + ) -> StrategyGraph[ScenarioStep, SweepThenDeepDiveState]: + """ + Build the branching graph that drives the sweep → deep-dive flow. + + Ignores the orchestrator-supplied ``steps`` argument: this scenario + owns the partitioning of ``self._atomic_attacks`` into the sweep + and deep-dive phases and cannot be driven by a flat step list. If + the resume filter removes the sweep atomic (because it already + completed), the sweep step's wrapped attack will return its + previously-completed results from memory; the policy still runs + the classification pass to populate ``self._weak_categories``. + + Note: + This override widens ``StateT`` from ``int`` (the base-class + default for the linear policy) to a per-scenario enum. The + base method is an extension point, but the static return type + cannot be expressed without making ``Scenario`` generic over + ``StateT``. Long-term fix tracked as a Phase 9 cleanup; the + runtime contract (``StrategyGraph[ScenarioStep, Any]``) is + preserved. See plan.md for context. + + Args: + steps: Ignored. Present only to honor the base-class signature. + + Returns: + StrategyGraph[ScenarioStep, SweepThenDeepDiveState]: The branching + graph instance. + """ + # Reset shared mutable state for this attempt. + self._weak_categories = set() + + sweep_step = CategoryAggregatingSweepStep( + atomic_attack_name=self._sweep_atomic.atomic_attack_name, + atomic_attack=self._sweep_atomic, + outcome_scorer=self._outcome_scorer, + weakness_label=self._weakness_label, + max_concurrency=self._max_concurrency, + ) + deep_dive_step = FilteredDeepDiveStep( + atomic_attack_name=f"{self._sweep_atomic.atomic_attack_name}_deep_dive", + atomic_attacks=self._deep_dive_atomics, + weak_categories_ref=lambda: self._weak_categories, + max_concurrency=self._max_concurrency, + ) + + policy = self._build_branching_policy(sweep_step=sweep_step, deep_dive_step=deep_dive_step) + return StrategyGraph(policy=policy) + + def _build_branching_policy( + self, + *, + sweep_step: CategoryAggregatingSweepStep, + deep_dive_step: FilteredDeepDiveStep, + ) -> StrategyPolicy[ScenarioStep, SweepThenDeepDiveState]: + """ + Compose the two-action policy for the branching graph. + + States: + - ``SWEEPING`` → ``_sweep_action`` runs the sweep step. Emits + ``DEEP_DIVING`` if ``"found_weaknesses"`` was the sweep + outcome, else ``ALL_SAFE`` (terminal). + - ``DEEP_DIVING`` → ``_deep_dive_action`` runs the deep-dive + step. Always emits ``COMPLETE`` (terminal). + - ``COMPLETE`` and ``ALL_SAFE`` are terminal states. + + Args: + sweep_step: The sweep step to dispatch from ``SWEEPING``. + deep_dive_step: The deep-dive step to dispatch from + ``DEEP_DIVING``. + + Returns: + StrategyPolicy: A frozen policy with the two-state branching + action map. + """ + + async def _sweep_action( + graph: StrategyGraph[ScenarioStep, SweepThenDeepDiveState], + ) -> tuple[SweepThenDeepDiveState, ScenarioStepResult | None]: + graph.bind_current_step(step=sweep_step) + try: + base_result = await sweep_step.process_async() + # Update the closure-shared weak-categories set so the + # deep-dive step sees the fresh classification. + self._weak_categories = set(base_result.metadata.get("weak_categories", set())) + merged_metadata = {"step_name": sweep_step.name, **base_result.metadata} + result = ScenarioStepResult( + outcome=base_result.outcome, + attack_results=list(base_result.attack_results), + step_identifier=base_result.step_identifier, + metadata=merged_metadata, + ) + finally: + graph.bind_current_step(step=None) + + next_state = ( + SweepThenDeepDiveState.DEEP_DIVING + if base_result.outcome == "found_weaknesses" + else SweepThenDeepDiveState.ALL_SAFE + ) + return next_state, result + + async def _deep_dive_action( + graph: StrategyGraph[ScenarioStep, SweepThenDeepDiveState], + ) -> tuple[SweepThenDeepDiveState, ScenarioStepResult | None]: + graph.bind_current_step(step=deep_dive_step) + try: + base_result = await deep_dive_step.process_async() + merged_metadata = {"step_name": deep_dive_step.name, **base_result.metadata} + result = ScenarioStepResult( + outcome=base_result.outcome, + attack_results=list(base_result.attack_results), + step_identifier=base_result.step_identifier, + metadata=merged_metadata, + ) + finally: + graph.bind_current_step(step=None) + return SweepThenDeepDiveState.COMPLETE, result + + actions: dict[SweepThenDeepDiveState, PolicyAction[ScenarioStep, SweepThenDeepDiveState]] = { + SweepThenDeepDiveState.SWEEPING: _sweep_action, + SweepThenDeepDiveState.DEEP_DIVING: _deep_dive_action, + } + + return StrategyPolicy( + actions=actions, + initial_state=SweepThenDeepDiveState.SWEEPING, + terminal_states=frozenset({SweepThenDeepDiveState.COMPLETE, SweepThenDeepDiveState.ALL_SAFE}), + ) diff --git a/tests/unit/scenario/scenarios/airt/__init__.py b/tests/unit/scenario/scenarios/airt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py new file mode 100644 index 000000000..8e3ec23e6 --- /dev/null +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py @@ -0,0 +1,622 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Phase 7 — coverage for ``BroadSweepThenDeepDive``. + +These tests pin the contract of the first real branching scenario: + +- ``CategoryAggregatingSweepStep`` classifies each ``AttackResult`` via the + ``OutcomeScorer`` and emits the right outcome label plus weak-category set. +- ``FilteredDeepDiveStep`` gates each wrapped atomic on its category being in + the live weak set; non-matching atomics are skipped and stamped into + ``metadata['skipped_categories']``. +- ``BroadSweepThenDeepDive._build_execution_graph`` produces a two-action + branching policy that walks SWEEPING -> {DEEP_DIVING -> COMPLETE | ALL_SAFE} + and resets ``self._weak_categories`` on every invocation. +- The closure-shared ``weak_categories_ref`` correctly threads sweep output + into the deep-dive step at policy-dispatch time. +""" + +from __future__ import annotations + +from typing import Any, cast +from unittest.mock import MagicMock, PropertyMock + +import pytest + +from pyrit.executor.attack.core import AttackExecutorResult +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import AttackOutcome, AttackResult, MessagePiece +from pyrit.scenario import DatasetConfiguration +from pyrit.scenario.core import AtomicAttack +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy +from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + BroadSweepThenDeepDiveStrategy, + CategoryAggregatingSweepStep, + FilteredDeepDiveStep, + SweepThenDeepDiveState, +) +from pyrit.score import Scorer +from pyrit.score.decorators.outcome_scorer import OutcomeScorer + +_WEAKNESS_LABEL = "safety_violation" +_SAFE_LABEL = "safe" + + +def _save_results_to_memory(attack_results: list[AttackResult]) -> None: + # Memory persistence is intentionally skipped in mocks; these tests + # exercise the branching policy + classification path, not the + # persistence layer (covered by Phase 5 graph-execution tests). + return + + +def _make_scorer_id(name: str) -> ComponentIdentifier: + return ComponentIdentifier(class_name=name, class_module="tests.unit.scenario.scenarios.airt") + + +def _make_outcome_scorer(*, label_for: dict[str, str] | None = None) -> OutcomeScorer: + """Build an ``OutcomeScorer`` whose wrapped scorer maps response.text to a label. + + The wrapped scorer pulls the response text out of the supplied ``Message`` + and looks it up in ``label_for`` to decide whether the score predicate + should return weakness or safe. This lets each test feed deterministic + per-category labels without writing a full ``Scorer`` subclass. + """ + label_for = label_for or {} + scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = _make_scorer_id("MockScorer") + + async def _score_async(message: Any, *, objective: Any = None) -> list[Any]: + text = message.message_pieces[0].original_value if message.message_pieces else "" + label = label_for.get(text, _SAFE_LABEL) + score = MagicMock() + score.score_value = label + return [score] + + scorer.score_async = MagicMock(side_effect=_score_async) + + return OutcomeScorer( + wrapped_scorer=scorer, + outcome_map={ + _WEAKNESS_LABEL: lambda s: s.score_value == _WEAKNESS_LABEL, + _SAFE_LABEL: lambda s: s.score_value == _SAFE_LABEL, + }, + ) + + +def _attack_result(*, conversation_id: str, objective: str, response_text: str | None) -> AttackResult: + """Build an AttackResult with a real last_response piece (or None).""" + result = AttackResult( + conversation_id=conversation_id, + objective=objective, + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + result.atomic_attack_identifier = ComponentIdentifier( + class_name="MockAttack", + class_module="tests.unit.scenario.scenarios.airt", + params={"name": conversation_id}, + ) + if response_text is not None: + result.last_response = MessagePiece(role="assistant", original_value=response_text) + return result + + +def _make_atomic_mock( + *, + name: str, + display_group: str, + attack_results: list[AttackResult], + objectives: list[str] | None = None, +) -> MagicMock: + """Mock an ``AtomicAttack`` whose ``run_async`` returns the supplied results.""" + attack = MagicMock(spec=AtomicAttack) + attack.atomic_attack_name = name + attack.display_group = display_group + type(attack).objectives = PropertyMock( + return_value=list(objectives) if objectives is not None else [r.objective for r in attack_results] + ) + type(attack).seed_groups = PropertyMock(return_value=[]) + attack.get_identifier.return_value = ComponentIdentifier( + class_name="AtomicAttack", + class_module="tests.unit.scenario.scenarios.airt", + params={"name": name}, + ) + + async def _fake_run(*args: Any, **kwargs: Any) -> AttackExecutorResult: + _save_results_to_memory(attack_results) + return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_fake_run) + return attack + + +@pytest.mark.usefixtures("patch_central_database") +class TestCategoryAggregatingSweepStep: + """Pin the sweep step's classification, outcome, and metadata contract.""" + + def test_init_rejects_weakness_label_not_in_outcomes(self) -> None: + scorer = _make_outcome_scorer() + atomic = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + with pytest.raises(ValueError, match="weakness_label"): + CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label="not_a_real_label", + ) + + def test_init_surfaces_atomic_attack_attributes(self) -> None: + scorer = _make_outcome_scorer() + atomic = _make_atomic_mock( + name="sweep-name", + display_group="harms", + attack_results=[], + objectives=["obj-1", "obj-2"], + ) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep-name", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + assert isinstance(step, ScenarioStep) + assert step.name == "sweep-name" + assert step.atomic_attack_name == "sweep-name" + assert step.display_group == "harms" + assert step.objectives == ["obj-1", "obj-2"] + assert _WEAKNESS_LABEL in step.outputs or step.outputs == list( + CategoryAggregatingSweepStep._OUTPUTS, + ) + + def test_filter_seed_groups_forwards_to_wrapped_atomic(self) -> None: + scorer = _make_outcome_scorer() + atomic = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + step.filter_seed_groups_by_objectives(remaining_objectives=["obj-x"]) + atomic.filter_seed_groups_by_objectives.assert_called_once_with(remaining_objectives=["obj-x"]) + + async def test_process_emits_found_weaknesses_when_any_match(self) -> None: + scorer = _make_outcome_scorer( + label_for={"breach-A": _WEAKNESS_LABEL, "breach-B": _WEAKNESS_LABEL}, + ) + result_one = _attack_result( + conversation_id="c1", + objective="o1", + response_text="breach-A", + ) + result_two = _attack_result( + conversation_id="c2", + objective="o2", + response_text="breach-B", + ) + atomic = _make_atomic_mock( + name="sweep", + display_group="cat-a", + attack_results=[result_one, result_two], + ) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + + step_result = await step.process_async() + + # All results in this sweep share display_group "cat-a", so weak_categories + # is OR-aggregated across results. + assert step_result.outcome == "found_weaknesses" + assert step_result.metadata["weak_categories"] == {"cat-a"} + assert step_result.metadata["category_outcomes"]["cat-a"] == _WEAKNESS_LABEL + assert len(step_result.attack_results) == 2 + + async def test_process_marks_category_weak_even_when_only_one_result_is_weak(self) -> None: + """OR-semantics: a single weak result in a category flags the whole category.""" + scorer = _make_outcome_scorer( + label_for={"breach-A": _WEAKNESS_LABEL, "safe-text": _SAFE_LABEL}, + ) + result_weak = _attack_result( + conversation_id="c1", + objective="o1", + response_text="breach-A", + ) + result_safe = _attack_result( + conversation_id="c2", + objective="o2", + response_text="safe-text", + ) + atomic = _make_atomic_mock( + name="sweep", + display_group="cat-a", + attack_results=[result_weak, result_safe], + ) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + + step_result = await step.process_async() + + assert step_result.outcome == "found_weaknesses" + assert step_result.metadata["weak_categories"] == {"cat-a"} + + async def test_process_emits_all_safe_when_no_matches(self) -> None: + scorer = _make_outcome_scorer(label_for={"safe-text": _SAFE_LABEL}) + result_safe = _attack_result( + conversation_id="c1", + objective="o1", + response_text="safe-text", + ) + atomic = _make_atomic_mock( + name="sweep", + display_group="cat-a", + attack_results=[result_safe], + ) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + + step_result = await step.process_async() + + assert step_result.outcome == "all_safe" + assert step_result.metadata["weak_categories"] == set() + assert step_result.metadata["category_outcomes"] == {"cat-a": _SAFE_LABEL} + + async def test_process_returns_unscored_when_last_response_is_none(self) -> None: + scorer = _make_outcome_scorer() + result_no_response = _attack_result(conversation_id="c1", objective="o1", response_text=None) + atomic = _make_atomic_mock( + name="sweep", + display_group="cat-a", + attack_results=[result_no_response], + ) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + + step_result = await step.process_async() + + assert step_result.metadata["category_outcomes"]["cat-a"] == OutcomeScorer.UNSCORED + assert step_result.outcome == "all_safe" + + async def test_process_empty_results_emits_all_safe(self) -> None: + scorer = _make_outcome_scorer() + atomic = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + + step_result = await step.process_async() + + assert step_result.outcome == "all_safe" + assert step_result.metadata["weak_categories"] == set() + assert step_result.attack_results == [] + + def test_build_identifier_nests_attack_and_scorer(self) -> None: + scorer = _make_outcome_scorer() + atomic = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + step = CategoryAggregatingSweepStep( + atomic_attack_name="sweep", + atomic_attack=atomic, + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + identifier = step._build_identifier() + assert "sweep_attack" in identifier.children + assert "outcome_scorer" in identifier.children + assert identifier.params["weakness_label"] == _WEAKNESS_LABEL + + +@pytest.mark.usefixtures("patch_central_database") +class TestFilteredDeepDiveStep: + """Pin the deep-dive step's gating, metadata, and aggregation contract.""" + + def test_init_rejects_empty_atomic_list(self) -> None: + with pytest.raises(ValueError, match="at least one"): + FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[], + weak_categories_ref=lambda: set(), + ) + + def test_init_dedupes_objectives_across_atomics(self) -> None: + a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[], objectives=["o1", "o2"]) + b = _make_atomic_mock(name="b", display_group="cat-b", attack_results=[], objectives=["o2", "o3"]) + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[a, b], + weak_categories_ref=lambda: set(), + ) + assert isinstance(step, ScenarioStep) + assert step.objectives == ["o1", "o2", "o3"] + + async def test_process_dispatches_only_flagged_categories(self) -> None: + result_a = _attack_result(conversation_id="da", objective="oa", response_text="ra") + result_b = _attack_result(conversation_id="db", objective="ob", response_text="rb") + atomic_a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[result_a]) + atomic_b = _make_atomic_mock(name="b", display_group="cat-b", attack_results=[result_b]) + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[atomic_a, atomic_b], + weak_categories_ref=lambda: {"cat-a"}, + ) + + step_result = await step.process_async() + + atomic_a.run_async.assert_called_once() + atomic_b.run_async.assert_not_called() + assert step_result.outcome == "done" + assert step_result.metadata["dispatched_categories"] == ["cat-a"] + assert step_result.metadata["skipped_categories"] == ["cat-b"] + assert step_result.attack_results == [result_a] + + async def test_process_skips_all_when_weak_set_empty(self) -> None: + result_a = _attack_result(conversation_id="da", objective="oa", response_text="ra") + atomic_a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[result_a]) + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[atomic_a], + weak_categories_ref=lambda: set(), + ) + + step_result = await step.process_async() + + atomic_a.run_async.assert_not_called() + assert step_result.outcome == "done" + assert step_result.metadata["dispatched_categories"] == [] + assert step_result.metadata["skipped_categories"] == ["cat-a"] + assert step_result.attack_results == [] + + async def test_process_reads_weak_categories_live_at_dispatch_time(self) -> None: + """The closure must be re-evaluated at dispatch time, not at init.""" + weak: set[str] = set() + result_a = _attack_result(conversation_id="da", objective="oa", response_text="ra") + atomic_a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[result_a]) + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[atomic_a], + weak_categories_ref=lambda: weak, + ) + # Mutate the set AFTER init but BEFORE process_async runs. + weak.add("cat-a") + + step_result = await step.process_async() + + atomic_a.run_async.assert_called_once() + assert step_result.metadata["dispatched_categories"] == ["cat-a"] + + def test_filter_seed_groups_forwards_to_each_atomic(self) -> None: + a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + b = _make_atomic_mock(name="b", display_group="cat-b", attack_results=[]) + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[a, b], + weak_categories_ref=lambda: set(), + ) + step.filter_seed_groups_by_objectives(remaining_objectives=["obj-z"]) + a.filter_seed_groups_by_objectives.assert_called_once_with(remaining_objectives=["obj-z"]) + b.filter_seed_groups_by_objectives.assert_called_once_with(remaining_objectives=["obj-z"]) + assert step.objectives == ["obj-z"] + + def test_build_identifier_sorts_nested_atomics(self) -> None: + a = _make_atomic_mock(name="b-second", display_group="cat-b", attack_results=[]) + b = _make_atomic_mock(name="a-first", display_group="cat-a", attack_results=[]) + + step_ab = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[a, b], + weak_categories_ref=lambda: set(), + ) + step_ba = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[b, a], + weak_categories_ref=lambda: set(), + ) + + id_ab = step_ab._build_identifier() + id_ba = step_ba._build_identifier() + + nested_ab = id_ab.children["deep_dive_attacks"] + assert isinstance(nested_ab, list) + # Sorted by atomic_attack_name, so 'a-first' must precede 'b-second'. + assert [c.params["name"] for c in nested_ab] == ["a-first", "b-second"] + assert id_ab == id_ba + + +@pytest.mark.usefixtures("patch_central_database") +class TestBroadSweepThenDeepDive: + """End-to-end coverage of the branching scenario and its graph.""" + + @staticmethod + def _build_scenario( + *, + sweep_response_text: str, + scorer_label_for: dict[str, str], + deep_dive_display_groups: list[str], + ) -> tuple[ + BroadSweepThenDeepDive, + MagicMock, + list[MagicMock], + ]: + scorer = _make_outcome_scorer(label_for=scorer_label_for) + sweep_result = _attack_result( + conversation_id="sweep-c", + objective="sweep-obj", + response_text=sweep_response_text, + ) + sweep_atomic = _make_atomic_mock( + name="sweep-atomic", + display_group="cat-a", + attack_results=[sweep_result], + ) + deep_dives = [ + _make_atomic_mock( + name=f"deep-{i}", + display_group=group, + attack_results=[ + _attack_result( + conversation_id=f"deep-c-{i}", + objective=f"deep-obj-{i}", + response_text=f"deep-text-{i}", + ) + ], + ) + for i, group in enumerate(deep_dive_display_groups) + ] + scenario = BroadSweepThenDeepDive( + sweep_atomic_attack=cast("AtomicAttack", sweep_atomic), + deep_dive_atomic_attacks=[cast("AtomicAttack", d) for d in deep_dives], + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + return scenario, sweep_atomic, deep_dives + + def test_constructor_rejects_empty_deep_dive_list(self) -> None: + scorer = _make_outcome_scorer() + sweep = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + with pytest.raises(ValueError, match="deep_dive_atomic_attack"): + BroadSweepThenDeepDive( + sweep_atomic_attack=cast("AtomicAttack", sweep), + deep_dive_atomic_attacks=[], + outcome_scorer=scorer, + ) + + def test_strategy_metadata(self) -> None: + assert BroadSweepThenDeepDive.get_strategy_class() is BroadSweepThenDeepDiveStrategy + assert BroadSweepThenDeepDive.get_default_strategy() is BroadSweepThenDeepDiveStrategy.DEFAULT + assert isinstance(BroadSweepThenDeepDive.default_dataset_config(), DatasetConfiguration) + + async def test_get_atomic_attacks_returns_canonical_order(self) -> None: + scenario, sweep_atomic, deep_dives = self._build_scenario( + sweep_response_text="safe", + scorer_label_for={"safe": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a", "cat-b"], + ) + atomics = await scenario._get_atomic_attacks_async() + assert atomics[0] is sweep_atomic + assert atomics[1:] == deep_dives + + def test_build_execution_graph_returns_branching_strategy_graph(self) -> None: + scenario, _, _ = self._build_scenario( + sweep_response_text="safe", + scorer_label_for={"safe": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a"], + ) + graph = scenario._build_execution_graph() + assert isinstance(graph, StrategyGraph) + policy = graph._policy + assert isinstance(policy, StrategyPolicy) + assert policy.initial_state == SweepThenDeepDiveState.SWEEPING + assert SweepThenDeepDiveState.COMPLETE in policy.terminal_states + assert SweepThenDeepDiveState.ALL_SAFE in policy.terminal_states + # Both non-terminal states must have an action. + assert SweepThenDeepDiveState.SWEEPING in policy.actions + assert SweepThenDeepDiveState.DEEP_DIVING in policy.actions + + def test_build_execution_graph_resets_weak_categories(self) -> None: + scenario, _, _ = self._build_scenario( + sweep_response_text="safe", + scorer_label_for={"safe": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a"], + ) + scenario._weak_categories = {"stale-from-previous-attempt"} + scenario._build_execution_graph() + assert scenario._weak_categories == set() + + async def test_event_loop_with_weakness_drives_deep_dive_to_complete(self) -> None: + scenario, sweep_atomic, deep_dives = self._build_scenario( + sweep_response_text="breach-A", + scorer_label_for={"breach-A": _WEAKNESS_LABEL}, + deep_dive_display_groups=["cat-a", "cat-b"], + ) + graph = scenario._build_execution_graph() + + results: list[ScenarioStepResult] = [result async for result in graph.event_loop_async()] + + assert graph.current_state == SweepThenDeepDiveState.COMPLETE + assert len(results) == 2 + # Sweep step ran once. + sweep_atomic.run_async.assert_called_once() + # Only the cat-a deep-dive should have been dispatched. + deep_dives[0].run_async.assert_called_once() + deep_dives[1].run_async.assert_not_called() + assert results[1].metadata["dispatched_categories"] == ["cat-a"] + assert results[1].metadata["skipped_categories"] == ["cat-b"] + + async def test_event_loop_short_circuits_to_all_safe(self) -> None: + scenario, sweep_atomic, deep_dives = self._build_scenario( + sweep_response_text="safe-text", + scorer_label_for={"safe-text": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a", "cat-b"], + ) + graph = scenario._build_execution_graph() + + results: list[ScenarioStepResult] = [result async for result in graph.event_loop_async()] + + assert graph.current_state == SweepThenDeepDiveState.ALL_SAFE + # Only the sweep step's result is in history; deep dive never ran. + assert len(results) == 1 + sweep_atomic.run_async.assert_called_once() + for d in deep_dives: + d.run_async.assert_not_called() + assert results[0].outcome == "found_weaknesses" or results[0].outcome == "all_safe" + assert results[0].metadata["weak_categories"] == set() + + async def test_weak_categories_propagate_from_sweep_to_deep_dive(self) -> None: + """The closure-shared set must hand the right categories to the deep dive.""" + scenario, _, deep_dives = self._build_scenario( + sweep_response_text="breach-A", + scorer_label_for={"breach-A": _WEAKNESS_LABEL}, + deep_dive_display_groups=["cat-a", "cat-b", "cat-c"], + ) + graph = scenario._build_execution_graph() + + async for _ in graph.event_loop_async(): + pass + + # The sweep step's display_group is "cat-a", so only cat-a is flagged. + deep_dives[0].run_async.assert_called_once() + deep_dives[1].run_async.assert_not_called() + deep_dives[2].run_async.assert_not_called() + assert scenario._weak_categories == {"cat-a"} + + async def test_consecutive_runs_reset_state(self) -> None: + """Re-invoking _build_execution_graph clears prior weak_categories cleanly.""" + scenario, _, _ = self._build_scenario( + sweep_response_text="safe", + scorer_label_for={"safe": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a"], + ) + scenario._weak_categories = {"ghost-category"} + graph1 = scenario._build_execution_graph() + assert scenario._weak_categories == set() + async for _ in graph1.event_loop_async(): + pass + + scenario._weak_categories = {"another-ghost"} + graph2 = scenario._build_execution_graph() + assert scenario._weak_categories == set() + async for _ in graph2.event_loop_async(): + pass From 325bea9ace5c02f2c2918e04f8ed341901e0e868 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 14:00:06 -0700 Subject: [PATCH 14/42] TEST: augment Scenario.run_async graph rewire coverage (Phase 10e) Add 4 tests covering Phase 5 (commit 952311d9) gaps in Scenario.run_async: - TestStepIdentifierStampingNoDuplication (2 tests): verify the step_identifier stamping path uses update_attack_result_by_id and never inserts duplicate rows, both for single- and multi-result steps. - TestExecutionGraphRebuildOnRetry (1 test): verify the execution graph is rebuilt from the resume-filtered remaining steps after a partial failure, so terminal_states shrinks on retry. - TestFactoryAtomicAttackGraphIntegration (1 test): end-to-end integration through AttackTechniqueFactory -> AttackTechnique -> AtomicAttack -> StrategyGraph execution path, asserting the factory-built attack is the one the executor receives and that step_identifier is stamped on the resulting AttackResult. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenario/test_attack_technique_factory.py | 91 ++++++++++++ .../scenario/test_scenario_graph_execution.py | 137 +++++++++++++++++- 2 files changed, 220 insertions(+), 8 deletions(-) diff --git a/tests/unit/scenario/test_attack_technique_factory.py b/tests/unit/scenario/test_attack_technique_factory.py index 6b78a282a..2e19f5605 100644 --- a/tests/unit/scenario/test_attack_technique_factory.py +++ b/tests/unit/scenario/test_attack_technique_factory.py @@ -572,3 +572,94 @@ def test_unwrap_non_type_annotation_returns_none(self): """A non-type annotation (e.g., string forward ref) returns None.""" result = AttackTechniqueFactory._unwrap_optional("SomeForwardRef") assert result is None + + +@pytest.mark.usefixtures("patch_central_database") +class TestFactoryAtomicAttackGraphIntegration: + """End-to-end wiring: factory → AttackTechnique → AtomicAttack → Scenario graph. + + Pins that an ``AttackTechnique`` produced by ``AttackTechniqueFactory.create()`` + plugs into ``AtomicAttack`` and is driven through ``Scenario.run_async``'s + Phase 5 ``StrategyGraph`` orchestrator without per-component glue. + """ + + async def test_factory_produced_technique_runs_through_scenario_graph(self): + from unittest.mock import AsyncMock, patch + + from pyrit.executor.attack import AttackExecutor + from pyrit.executor.attack.core import AttackExecutorResult + from pyrit.models import ( + AttackOutcome, + AttackResult, + SeedAttackGroup, + SeedObjective, + ) + from pyrit.scenario import AtomicAttack, ScenarioResult + + # Local import to keep collection-time light and avoid coupling with the + # factory-only test suite. + from tests.unit.scenario.test_scenario_graph_execution import _GraphConcreteScenario + + objective_target = MagicMock(spec=PromptTarget) + objective_target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", + class_module="tests.unit", + ) + scoring_config = MagicMock(spec=AttackScoringConfig) + + factory = AttackTechniqueFactory(attack_class=_StubAttack) + technique = factory.create( + objective_target=objective_target, + attack_scoring_config=scoring_config, + ) + assert isinstance(technique, AttackTechnique) + assert isinstance(technique.attack, _StubAttack) + + seed_group = SeedAttackGroup(seeds=[SeedObjective(value="integration_obj")]) + atomic = AtomicAttack( + atomic_attack_name="integration_step", + attack_technique=technique, + seed_groups=[seed_group], + ) + + canned_result = AttackResult( + conversation_id="conv-integration", + objective="integration_obj", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + scenario = _GraphConcreteScenario( + name="FactoryIntegration", + version=1, + atomic_attacks_to_return=[atomic], + ) + await scenario.initialize_async(objective_target=objective_target) + + # The real AttackExecutor persists each AttackResult before returning. Our patched + # executor must do the same so the scenario can rehydrate results from memory at + # ``get_scenario_results`` time. + async def _fake_execute(*args, **kwargs): + from pyrit.memory import CentralMemory + + CentralMemory.get_memory_instance().add_attack_results_to_memory(attack_results=[canned_result]) + return AttackExecutorResult( + completed_results=[canned_result], + incomplete_objectives=[], + input_indices=[0], + ) + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.side_effect = _fake_execute + result = await scenario.run_async() + + assert isinstance(result, ScenarioResult) + assert "integration_step" in result.attack_results + assert len(result.attack_results["integration_step"]) == 1 + # The orchestrator must have stamped a step_identifier on the factory-produced + # technique's result during graph execution. + stamped = result.attack_results["integration_step"][0].step_identifier + assert stamped is not None + assert stamped.class_name == "ScenarioStep" + # The executor should have received the very ``_StubAttack`` instance the factory built. + forwarded_attack = mock_exec.call_args.kwargs["attack"] + assert forwarded_attack is technique.attack diff --git a/tests/unit/scenario/test_scenario_graph_execution.py b/tests/unit/scenario/test_scenario_graph_execution.py index bd34e9c46..e60cb2966 100644 --- a/tests/unit/scenario/test_scenario_graph_execution.py +++ b/tests/unit/scenario/test_scenario_graph_execution.py @@ -158,7 +158,9 @@ def test_raises_when_explicit_steps_empty(self, mock_objective_target): async def test_default_graph_terminates_at_len_steps(self, mock_objective_target): attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(3)] scenario = _GraphConcreteScenario( - name="Default", version=1, atomic_attacks_to_return=attacks, + name="Default", + version=1, + atomic_attacks_to_return=attacks, ) await scenario.initialize_async(objective_target=mock_objective_target) @@ -173,7 +175,9 @@ async def test_default_graph_terminates_at_len_steps(self, mock_objective_target async def test_explicit_steps_override_atomic_attacks(self, mock_objective_target): attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(3)] scenario = _GraphConcreteScenario( - name="Default", version=1, atomic_attacks_to_return=attacks, + name="Default", + version=1, + atomic_attacks_to_return=attacks, ) await scenario.initialize_async(objective_target=mock_objective_target) @@ -194,7 +198,9 @@ def test_graph_and_history_are_empty_before_run(self): async def test_run_async_populates_graph_and_history(self, mock_objective_target): attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] scenario = _GraphConcreteScenario( - name="Populated", version=1, atomic_attacks_to_return=attacks, + name="Populated", + version=1, + atomic_attacks_to_return=attacks, ) await scenario.initialize_async(objective_target=mock_objective_target) @@ -214,7 +220,9 @@ class TestStepIdentifierStamping: async def test_each_result_has_step_identifier(self, mock_objective_target): attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] scenario = _GraphConcreteScenario( - name="Stamping", version=1, atomic_attacks_to_return=attacks, + name="Stamping", + version=1, + atomic_attacks_to_return=attacks, ) await scenario.initialize_async(objective_target=mock_objective_target) @@ -250,7 +258,9 @@ async def _run_returning_stamped(*args, **kwargs): attack.run_async = MagicMock(side_effect=_run_returning_stamped) scenario = _GraphConcreteScenario( - name="Pre-stamped", version=1, atomic_attacks_to_return=[attack], + name="Pre-stamped", + version=1, + atomic_attacks_to_return=[attack], ) await scenario.initialize_async(objective_target=mock_objective_target) @@ -265,6 +275,111 @@ async def _run_returning_stamped(*args, **kwargs): assert stamped.params.get("eval_version") == 99 +@pytest.mark.usefixtures("patch_central_database") +class TestStepIdentifierStampingNoDuplication: + """The orchestrator's per-step ``update_attack_result_by_id`` enrichment must + not introduce duplicate ``AttackResultEntry`` rows. + + Phase 5 routes results through ``StrategyGraph`` and enriches each result with + a ``step_identifier`` via ``update_attack_result_by_id``. The inner attack is + the sole insert site (mirrored here by ``_save_results_to_memory``); the + orchestrator only updates. This regression test guards against accidentally + flipping the update into an insert. + """ + + async def test_no_duplicate_attack_results_after_run(self, mock_objective_target): + # Two atomic attacks, one result each → memory should hold exactly 2 rows. + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] + scenario = _GraphConcreteScenario( + name="Dedup", + version=1, + atomic_attacks_to_return=attacks, + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + await scenario.run_async() + + memory = CentralMemory.get_memory_instance() + persisted = memory.get_attack_results() + assert len(persisted) == 2 + # Each persisted row must carry the step_identifier stamped by the orchestrator. + for ar in persisted: + assert ar.step_identifier is not None + assert ar.step_identifier.class_name == "ScenarioStep" + + async def test_no_duplicate_results_for_multi_result_step(self, mock_objective_target): + # One atomic attack returning two results — still exactly two rows after stamping. + result_a = _sample_result(0) + result_b = _sample_result(1) + attack = _make_atomic_attack_mock("multi", result_a) + + async def _run_multi(*args, **kwargs): + _save_results_to_memory([result_a, result_b]) + return AttackExecutorResult(completed_results=[result_a, result_b], incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_run_multi) + + scenario = _GraphConcreteScenario( + name="DedupMulti", + version=1, + atomic_attacks_to_return=[attack], + ) + await scenario.initialize_async(objective_target=mock_objective_target) + + await scenario.run_async() + + persisted = CentralMemory.get_memory_instance().get_attack_results() + assert len(persisted) == 2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestExecutionGraphRebuildOnRetry: + """The execution graph is rebuilt from resume-filtered steps on each attempt. + + After a failed attempt, the next attempt's ``execution_graph`` must reflect + only the still-pending steps. This pins the Phase 5 contract that + ``_build_execution_graph(steps=remaining_attacks)`` is invoked per attempt. + """ + + async def test_graph_terminal_states_shrink_after_partial_success(self, mock_objective_target): + # Attack 1 always succeeds; attack 2 fails once then succeeds. + attack_success = _make_atomic_attack_mock("a_success", _sample_result(0)) + + call_count = [0] + result_second = _sample_result(1) + + async def _run_flaky(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("first attempt failure") + _save_results_to_memory([result_second]) + return AttackExecutorResult(completed_results=[result_second], incomplete_objectives=[]) + + attack_flaky = _make_atomic_attack_mock("a_flaky", result_second) + attack_flaky.run_async = MagicMock(side_effect=_run_flaky) + # Mock the resume filter; orchestrator drops attacks whose objectives are all done. + attack_success.filter_seed_groups_by_objectives = MagicMock() + attack_flaky.filter_seed_groups_by_objectives = MagicMock() + + scenario = _GraphConcreteScenario( + name="RebuildOnRetry", + version=1, + atomic_attacks_to_return=[attack_success, attack_flaky], + ) + await scenario.initialize_async(objective_target=mock_objective_target, max_retries=1) + + await scenario.run_async() + + # After retry, the graph should reflect only the one remaining step (attack_flaky). + # ``execution_graph`` holds the most-recent attempt's graph. + graph = scenario.execution_graph + assert graph is not None + assert graph.policy.terminal_states == frozenset({1}) + # Sanity: the flaky attack ran twice (initial + retry); the success attack only once. + assert call_count[0] == 2 + attack_success.run_async.assert_called_once() + + @pytest.mark.usefixtures("patch_central_database") class TestMaxConcurrencyPropagation: """``max_concurrency`` flows from the scenario through the default linear policy.""" @@ -272,7 +387,9 @@ class TestMaxConcurrencyPropagation: async def test_atomic_attack_receives_scenario_max_concurrency(self, mock_objective_target): attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] scenario = _GraphConcreteScenario( - name="Concurrency", version=1, atomic_attacks_to_return=attacks, + name="Concurrency", + version=1, + atomic_attacks_to_return=attacks, ) await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=7) @@ -300,7 +417,9 @@ async def _run_partial(*args, **kwargs): attack.run_async = MagicMock(side_effect=_run_partial) scenario = _GraphConcreteScenario( - name="Partial", version=1, atomic_attacks_to_return=[attack], + name="Partial", + version=1, + atomic_attacks_to_return=[attack], ) await scenario.initialize_async(objective_target=mock_objective_target) @@ -383,7 +502,9 @@ async def test_custom_step_routed_through_process_async(self, mock_objective_tar step_b = _CountingStep(name="custom_b") scenario = _CustomStepScenario( - steps=[step_a, step_b], name="Custom-steps", version=1, + steps=[step_a, step_b], + name="Custom-steps", + version=1, ) await scenario.initialize_async(objective_target=mock_objective_target) From aec0bd00b6d2eecb130f5d1f151f7e2e0e18d22d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 14:33:34 -0700 Subject: [PATCH 15/42] Apply ruff format to scenario core refactor Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/attack_result.py | 4 +--- pyrit/scenario/core/scenario.py | 9 ++------- pyrit/scenario/core/strategy_graph.py | 13 +++---------- pyrit/score/decorators/outcome_scorer.py | 3 +-- .../identifiers/test_step_evaluation_identifier.py | 5 +---- 5 files changed, 8 insertions(+), 26 deletions(-) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index ea388e8de..65c939fef 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -282,9 +282,7 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: else None ), step_identifier=( - ComponentIdentifier.from_dict(data["step_identifier"]) - if data.get("step_identifier") - else None + ComponentIdentifier.from_dict(data["step_identifier"]) if data.get("step_identifier") else None ), last_response=(MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None), last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 7d23e035d..e50a061a2 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -1129,9 +1129,7 @@ def _build_execution_graph( ) return StrategyGraph(policy=self._build_default_linear_policy(steps=effective_steps)) - def _build_default_linear_policy( - self, *, steps: Sequence[ScenarioStep] - ) -> StrategyPolicy[ScenarioStep, int]: + def _build_default_linear_policy(self, *, steps: Sequence[ScenarioStep]) -> StrategyPolicy[ScenarioStep, int]: """ Build a linear-traversal policy that preserves scenario-level execution params. @@ -1404,10 +1402,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: # before truncation so it survives the DB round-trip, then push the enriched # identifier back to the AttackResultEntry row by attack_result_id. for attack_result in step_result.attack_results: - if ( - attack_result.step_identifier is None - and attack_result.atomic_attack_identifier is not None - ): + if attack_result.step_identifier is None and attack_result.atomic_attack_identifier is not None: new_identifier = build_step_identifier( step_name=step_name, outcome=step_result.outcome, diff --git a/pyrit/scenario/core/strategy_graph.py b/pyrit/scenario/core/strategy_graph.py index f426f041a..48bd342e3 100644 --- a/pyrit/scenario/core/strategy_graph.py +++ b/pyrit/scenario/core/strategy_graph.py @@ -83,16 +83,11 @@ def __post_init__(self) -> None: if not self.terminal_states: raise ValueError("StrategyPolicy requires at least one terminal state.") if self.initial_state in self.terminal_states: - raise ValueError( - f"initial_state {self.initial_state!r} is in terminal_states; " - f"the graph would do no work." - ) + raise ValueError(f"initial_state {self.initial_state!r} is in terminal_states; the graph would do no work.") overlap = [state for state in self.actions if state in self.terminal_states] if overlap: - raise ValueError( - f"Terminal states must not appear in actions: {overlap!r}." - ) + raise ValueError(f"Terminal states must not appear in actions: {overlap!r}.") object.__setattr__(self, "actions", MappingProxyType(dict(self.actions))) object.__setattr__(self, "terminal_states", frozenset(self.terminal_states)) @@ -115,9 +110,7 @@ def get_action(self, *, state: StateT) -> PolicyAction[StepT, StateT]: return self.actions[state] except KeyError: known = ", ".join(sorted(str(s) for s in self.actions)) - raise KeyError( - f"No action defined for state {state!r}. Known states: {known or '(none)'}." - ) from None + raise KeyError(f"No action defined for state {state!r}. Known states: {known or '(none)'}.") from None def is_terminal(self, *, state: StateT) -> bool: """Return ``True`` if ``state`` is a terminal state.""" diff --git a/pyrit/score/decorators/outcome_scorer.py b/pyrit/score/decorators/outcome_scorer.py index 9bce0eedd..3115e3e9b 100644 --- a/pyrit/score/decorators/outcome_scorer.py +++ b/pyrit/score/decorators/outcome_scorer.py @@ -66,8 +66,7 @@ def __init__( raise ValueError("OutcomeScorer requires a non-empty outcome_map.") if self.UNSCORED in outcome_map: raise ValueError( - f"Label {self.UNSCORED!r} is reserved as the no-match sentinel " - f"and may not appear in outcome_map." + f"Label {self.UNSCORED!r} is reserved as the no-match sentinel and may not appear in outcome_map." ) self._wrapped_scorer = wrapped_scorer diff --git a/tests/unit/identifiers/test_step_evaluation_identifier.py b/tests/unit/identifiers/test_step_evaluation_identifier.py index 082adbdc1..379ccf2f5 100644 --- a/tests/unit/identifiers/test_step_evaluation_identifier.py +++ b/tests/unit/identifiers/test_step_evaluation_identifier.py @@ -119,10 +119,7 @@ def test_nested_objective_target_operational_params_ignored(self): outcome="done", attack_execution_identifiers=[noisy_atomic], ) - assert ( - StepEvaluationIdentifier(noisy).eval_hash - == StepEvaluationIdentifier(baseline).eval_hash - ) + assert StepEvaluationIdentifier(noisy).eval_hash == StepEvaluationIdentifier(baseline).eval_hash def test_nested_objective_target_temperature_change_changes_hash(self): hot = StepEvaluationIdentifier(_build_step(target_temp=0.7)).eval_hash From be970a36c3362807b499b5551d6b76adcbe5ac72 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 14:55:36 -0700 Subject: [PATCH 16/42] finish BaselinePolicy -> BaselineAttackPolicy rename after main merge Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/core/__init__.py | 2 +- pyrit/scenario/scenarios/adaptive/adaptive_scenario.py | 4 ++-- pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py | 2 +- .../unit/scenario/scenarios/adaptive/test_text_adaptive.py | 6 +++--- tests/unit/scenario/test_scenario_graph_execution.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index c8e6aa4cf..86f1433a4 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -8,7 +8,7 @@ from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration -from pyrit.scenario.core.scenario import BaselinePolicy, Scenario +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index 9b0588e01..b9b2bf87e 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, ClassVar, cast from pyrit.executor.attack import AttackScoringConfig -from pyrit.scenario.core.scenario import BaselinePolicy, Scenario +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep @@ -57,7 +57,7 @@ class AdaptiveScenario(Scenario): rehydration are handled here. """ - BASELINE_POLICY: ClassVar[BaselinePolicy] = BaselinePolicy.Forbidden + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden #: Subclasses must declare a scenario version for memory bookkeeping. VERSION: ClassVar[int] diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index b2d52ca01..9a23cf1c4 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -66,7 +66,7 @@ class SweepThenDeepDiveState(str, Enum): The three states that drive ``BroadSweepThenDeepDive``. Inherits ``str`` for the canonical ``(str, Enum)`` design language used - by ``ScenarioCoreState``, ``BaselinePolicy``, and ``AttackOutcome``: the + by ``ScenarioCoreState``, ``BaselineAttackPolicy``, and ``AttackOutcome``: the enum members serialize naturally in identifiers and logs. The two terminal states are distinct so that downstream consumers can diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index bd5788d73..bea5dba6c 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -14,7 +14,7 @@ from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.dataset_configuration import DatasetConfiguration -from pyrit.scenario.core.scenario import BaselinePolicy +from pyrit.scenario.core.scenario import BaselineAttackPolicy from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep @@ -113,7 +113,7 @@ def test_version(self): assert TextAdaptive.VERSION == 1 def test_baseline_forbidden(self): - assert TextAdaptive.BASELINE_POLICY is BaselinePolicy.Forbidden + assert TextAdaptive.BASELINE_ATTACK_POLICY is BaselineAttackPolicy.Forbidden def test_default_dataset_config(self): config = TextAdaptive.default_dataset_config() @@ -501,7 +501,7 @@ def test_memory_load_failure_is_swallowed(self, mock_objective_scorer): @pytest.mark.usefixtures(*FIXTURES) -class TestTextAdaptiveBaselinePolicy: +class TestTextAdaptiveBaselineAttackPolicy: async def test_initialize_async_rejects_explicit_baseline(self, mock_objective_target, mock_objective_scorer): groups = {"violence": [_make_seed_group(value="obj", harm_categories=["violence"])]} with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): diff --git a/tests/unit/scenario/test_scenario_graph_execution.py b/tests/unit/scenario/test_scenario_graph_execution.py index e60cb2966..ea8583dc4 100644 --- a/tests/unit/scenario/test_scenario_graph_execution.py +++ b/tests/unit/scenario/test_scenario_graph_execution.py @@ -21,7 +21,7 @@ from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult -from pyrit.scenario.core import AtomicAttack, BaselinePolicy, Scenario, ScenarioStrategy +from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy, Scenario, ScenarioStrategy from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.strategy_graph import ( PolicyAction, @@ -85,7 +85,7 @@ class _GraphConcreteScenario(Scenario): test_scenario fixture. """ - BASELINE_POLICY: ClassVar[BaselinePolicy] = BaselinePolicy.Forbidden + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden def __init__(self, atomic_attacks_to_return=None, **kwargs): class _TestStrategy(ScenarioStrategy): From fee626d36c47bd617358e90840f0eeac6c3e58f4 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 15:15:14 -0700 Subject: [PATCH 17/42] FEAT: Phase 8a schema foundation + forward waterfall --- .../attack_technique_registry.py | 1 + pyrit/scenario/__init__.py | 8 + pyrit/scenario/core/__init__.py | 6 + .../scenario/core/attack_technique_factory.py | 22 ++ pyrit/scenario/core/input_schema.py | 101 ++++++++ pyrit/scenario/core/scenario.py | 27 +++ pyrit/scenario/core/waterfall.py | 117 ++++++++++ tests/unit/scenario/test_input_schema.py | 99 ++++++++ tests/unit/scenario/test_waterfall.py | 218 ++++++++++++++++++ 9 files changed, 599 insertions(+) create mode 100644 pyrit/scenario/core/input_schema.py create mode 100644 pyrit/scenario/core/waterfall.py create mode 100644 tests/unit/scenario/test_input_schema.py create mode 100644 tests/unit/scenario/test_waterfall.py diff --git a/pyrit/registry/object_registries/attack_technique_registry.py b/pyrit/registry/object_registries/attack_technique_registry.py index 55e15a02e..e61eb60b5 100644 --- a/pyrit/registry/object_registries/attack_technique_registry.py +++ b/pyrit/registry/object_registries/attack_technique_registry.py @@ -272,6 +272,7 @@ def build_factory_from_spec( adversarial_config=adversarial_config, seed_technique=spec.seed_technique, scorer_override_policy=scorer_override_policy, + source_spec=spec, ) @staticmethod diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 7d3529ee6..cb113712f 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -24,6 +24,8 @@ BaselineAttackPolicy, DatasetConfiguration, PolicyAction, + RoleDescriptor, + RoleTag, Scenario, ScenarioCompositeStrategy, ScenarioCoreState, @@ -34,6 +36,8 @@ StrategyGraph, StrategyPolicy, linear_strategy_policy, + policy_to_spec, + spec_to_enum, ) # Import scenario submodules directly and register them as virtual subpackages @@ -66,6 +70,8 @@ "DatasetConfiguration", "Parameter", "PolicyAction", + "RoleDescriptor", + "RoleTag", "Scenario", "ScenarioCompositeStrategy", "ScenarioCoreState", @@ -83,4 +89,6 @@ "foundry", "garak", "linear_strategy_policy", + "policy_to_spec", + "spec_to_enum", ] diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 86f1433a4..063d5508c 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -8,6 +8,7 @@ from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult @@ -18,6 +19,7 @@ register_scenario_techniques, ) from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy, linear_strategy_policy +from pyrit.scenario.core.waterfall import policy_to_spec, spec_to_enum __all__ = [ "AtomicAttack", @@ -27,6 +29,8 @@ "DatasetConfiguration", "EXPLICIT_SEED_GROUPS_KEY", "PolicyAction", + "RoleDescriptor", + "RoleTag", "SCENARIO_TECHNIQUES", "Parameter", "Scenario", @@ -40,7 +44,9 @@ "StrategyGraph", "StrategyPolicy", "linear_strategy_policy", + "policy_to_spec", "register_scenario_techniques", "get_default_scorer_target", "get_default_adversarial_target", + "spec_to_enum", ] diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index a080c1550..e1daa1dbc 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -30,6 +30,9 @@ ) from pyrit.models import SeedAttackTechniqueGroup from pyrit.prompt_target import PromptTarget + from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueSpec, + ) logger = logging.getLogger(__name__) @@ -63,6 +66,7 @@ def __init__( adversarial_config: AttackAdversarialConfig | None = None, seed_technique: SeedAttackTechniqueGroup | None = None, scorer_override_policy: ScorerOverridePolicy = ScorerOverridePolicy.WARN, + source_spec: AttackTechniqueSpec | None = None, ) -> None: """ Initialize the factory with a technique-specific configuration. @@ -80,6 +84,10 @@ def __init__( seed_technique: Optional technique seed group to attach to created techniques. scorer_override_policy: What to do when a scenario's scorer is incompatible with the attack's ``attack_scoring_config`` type annotation. Defaults to WARN. + source_spec: Optional ``AttackTechniqueSpec`` this factory was built from. + Set by :meth:`AttackTechniqueRegistry.build_factory_from_spec` and used + by the Phase 8 waterfall to recover the declarative spec layer. + ``None`` for factories constructed directly. Raises: TypeError: If any kwarg name is not a valid constructor parameter, @@ -92,6 +100,7 @@ def __init__( self._adversarial_config = adversarial_config self._seed_technique = seed_technique self._scorer_override_policy = scorer_override_policy + self._source_spec = source_spec self._validate_kwargs() @@ -160,6 +169,19 @@ def adversarial_chat(self) -> PromptTarget | None: """The adversarial chat target baked into this factory, or None.""" return self._adversarial_config.target if self._adversarial_config else None + @property + def source_spec(self) -> AttackTechniqueSpec | None: + """ + The ``AttackTechniqueSpec`` this factory was built from, if any. + + Set by :meth:`AttackTechniqueRegistry.build_factory_from_spec`. ``None`` + for factories constructed directly via ``AttackTechniqueFactory(...)``. + + Used by the Phase 8 waterfall (``policy_to_spec``) to recover the + declarative spec layer from a configured scenario. + """ + return self._source_spec + def create( self, *, diff --git a/pyrit/scenario/core/input_schema.py b/pyrit/scenario/core/input_schema.py new file mode 100644 index 000000000..13504d2e3 --- /dev/null +++ b/pyrit/scenario/core/input_schema.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +RoleDescriptor — declarative schema for a scenario's ``__init__`` inputs. + +A scenario's ``__init__`` arguments fall into a small taxonomy: + +* **scalar** — primitive values (``int``, ``str``, ``bool``, ``float``). +* **choice** — a fixed set of literal alternatives. +* **registry_ref** — a name into a PyRIT registry (e.g. a target name). +* **factory** — a structured spec elicited via nested schema (``AttackTechniqueSpec``). +* **opaque** — a pre-built ``Identifiable`` instance the CLI wizard cannot + elicit (``AtomicAttack``, ``OutcomeScorer``, etc.). Programmatic + callers can supply these directly; CLI flows must use an artifact. + +``Scenario.input_schema()`` returns ``list[RoleDescriptor]`` declaring the rich-object +``__init__`` inputs; ``Scenario.supported_parameters()`` already declares the scalar +``initialize_async`` arguments. The two are intentionally orthogonal. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import GenericAlias + + +class RoleTag(str, Enum): + """Taxonomy of input roles the wizard / waterfall can elicit.""" + + SCALAR = "scalar" + CHOICE = "choice" + REGISTRY_REF = "registry_ref" + FACTORY = "factory" + OPAQUE = "opaque" + + +@dataclass(frozen=True) +class RoleDescriptor: + """ + Describes one ``__init__`` input of a scenario. + + Args: + name (str): Argument name on the scenario's ``__init__``. Must be a + valid Python identifier. + description (str): Human-readable description shown in wizard prompts + and ``--help`` output. + tag (RoleTag): The elicitation strategy class. See module docstring. + param_type (type | GenericAlias | None): Optional declared type for + coercion / validation. ``None`` skips type enforcement; callers are + then responsible for type correctness. + choices (tuple[Any, ...] | None): For :attr:`RoleTag.CHOICE` roles, + the allowed alternatives. ``None`` for other tags. + default (Any): Default value when the caller does not supply one. May + be ``None`` even for required roles (the wizard distinguishes via + ``required``). + required (bool): When True, the wizard must elicit a value; when False, + ``default`` is used if no value is supplied. + """ + + name: str + description: str + tag: RoleTag + param_type: type | GenericAlias | None = None + choices: tuple[Any, ...] | None = None + default: Any = None + required: bool = True + + def __post_init__(self) -> None: + """ + Validate structural invariants and normalize ``choices`` to a tuple. + + Raises: + ValueError: If ``name`` isn't a valid identifier, if a CHOICE-tagged + role has no choices, if a non-CHOICE role declares choices, or + if an OPAQUE role is optional without a default. + """ + if not self.name.isidentifier(): + raise ValueError(f"RoleDescriptor.name must be a valid Python identifier, got {self.name!r}.") + + if self.choices is not None and not isinstance(self.choices, tuple): + object.__setattr__(self, "choices", tuple(self.choices)) + + if self.tag is RoleTag.CHOICE: + if not self.choices: + raise ValueError(f"RoleDescriptor '{self.name}' tagged CHOICE must declare non-empty choices.") + elif self.choices is not None: + raise ValueError( + f"RoleDescriptor '{self.name}' tagged {self.tag.value} must not declare choices " + "(choices are only valid for CHOICE-tagged roles)." + ) + + if self.tag is RoleTag.OPAQUE and self.required is False and self.default is None: + raise ValueError( + f"RoleDescriptor '{self.name}' tagged OPAQUE cannot be optional without a default — " + "opaque roles cannot be elicited from the CLI, so a non-None default is required." + ) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index e50a061a2..3fa0fa1cc 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -36,6 +36,7 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.input_schema import RoleDescriptor from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_scorer_target @@ -363,6 +364,32 @@ def supported_parameters(cls) -> list[Parameter]: """ return [] + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + """ + Override to declare rich-object ``__init__`` inputs the wizard should elicit. + + Returns a ``list[RoleDescriptor]`` describing arguments that + :meth:`__init__` accepts beyond the standard scenario plumbing + (``scenario_result_id``, ``params``, ``memory_labels``). Each descriptor + carries a :class:`RoleTag` declaring how the role is elicited + (scalar, choice, registry reference, factory spec, or opaque instance). + + This is intentionally orthogonal to :meth:`supported_parameters`: + + * :meth:`supported_parameters` declares **scalar** arguments to + :meth:`initialize_async` (CLI ``--kebab-flag`` surface, unchanged). + * :meth:`input_schema` declares **rich-object** arguments to + :meth:`__init__` (wizard / programmatic surface). + + Default returns ``[]``; most scenarios accept no rich-object inputs + beyond the standard plumbing. + + Returns: + list[RoleDescriptor]: Declared roles (default: empty list). + """ + return [] + def _get_attack_technique_factories(self) -> dict[str, "AttackTechniqueFactory"]: """ Return the attack technique factories for this scenario. diff --git a/pyrit/scenario/core/waterfall.py b/pyrit/scenario/core/waterfall.py new file mode 100644 index 000000000..38da62b47 --- /dev/null +++ b/pyrit/scenario/core/waterfall.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Phase 8 waterfall — translation between scenario configuration layers. + +Three layers, two forward translations (this module): + +* **policy** — a configured :class:`Scenario` with a ready + :class:`StrategyGraph` (the executable artifact). +* **spec** — the technique catalog the policy uses + (``list[AttackTechniqueSpec]``). +* **strategy enum** — the public ``ScenarioStrategy`` enum members that + represent those techniques (the CLI / wizard surface). + +The forward direction is lossy but well-defined: + +* ``policy_to_spec`` extracts the spec layer. Returns ``[]`` when the scenario + does not use the technique registry pattern (e.g. policy-parameterized + scenarios like ``BroadSweepThenDeepDive``). +* ``spec_to_enum`` resolves specs to ``ScenarioStrategy`` members of the + scenario's strategy class. Returns ``None`` when no member matches. + +The inverse direction (``enum_to_spec`` + ``spec_to_policy_inputs``) lands in +Phase 8e; both are partial and best-effort. Bugs in any direction stay +localized to this file. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueSpec, + ) + from pyrit.scenario.core.scenario import Scenario + from pyrit.scenario.core.scenario_strategy import ScenarioStrategy + + +def policy_to_spec(scenario: Scenario) -> list[AttackTechniqueSpec]: + """ + Extract the technique catalog (spec layer) backing a configured scenario. + + Walks the scenario's selected strategies and the factories it returns from + ``_get_attack_technique_factories``, returning each factory's + :attr:`AttackTechniqueFactory.source_spec`. Factories without a source + spec (constructed directly rather than via + ``AttackTechniqueRegistry.build_factory_from_spec``) are skipped silently. + + Args: + scenario (Scenario): An initialized scenario (``initialize_async`` has + been called, so ``_scenario_strategies`` is populated). + + Returns: + list[AttackTechniqueSpec]: One spec per selected technique that has a + recoverable source spec. Empty list when the scenario uses no + technique registry (policy-parameterized scenarios like + ``BroadSweepThenDeepDive``) or when no factories carry source specs. + """ + selected = getattr(scenario, "_scenario_strategies", None) + if not selected: + return [] + + try: + factories = scenario._get_attack_technique_factories() + except (AttributeError, NotImplementedError): + return [] + + specs: list[AttackTechniqueSpec] = [] + for strategy in selected: + factory = factories.get(strategy.value) + if factory is None or factory.source_spec is None: + continue + specs.append(factory.source_spec) + return specs + + +def spec_to_enum( + scenario_cls: type[Scenario], + specs: list[AttackTechniqueSpec], +) -> list[ScenarioStrategy] | None: + """ + Resolve a spec list to ``ScenarioStrategy`` members of a scenario class. + + Matches by :attr:`AttackTechniqueSpec.name` against + ``scenario_cls.get_strategy_class()`` members. Returns ``None`` when any + spec cannot be resolved (e.g. the spec name is not a member of the + scenario's strategy class) — the caller is expected to fall back to + ``--from-artifact`` (Phase 8e). + + Args: + scenario_cls (type[Scenario]): The scenario class whose strategy enum + should be inspected. + specs (list[AttackTechniqueSpec]): Specs to resolve. An empty list + returns ``[]`` (vacuously valid). + + Returns: + list[ScenarioStrategy] | None: Matching enum members in input order, + or ``None`` when at least one spec cannot be resolved. + """ + if not specs: + return [] + + try: + strategy_cls = scenario_cls.get_strategy_class() + except (AttributeError, NotImplementedError): + return None + + by_value = {member.value: member for member in strategy_cls} + resolved: list[ScenarioStrategy] = [] + for spec in specs: + member = by_value.get(spec.name) + if member is None: + return None + resolved.append(member) + return resolved diff --git a/tests/unit/scenario/test_input_schema.py b/tests/unit/scenario/test_input_schema.py new file mode 100644 index 000000000..85a7eca37 --- /dev/null +++ b/tests/unit/scenario/test_input_schema.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Phase 8a — coverage for ``RoleDescriptor`` + ``RoleTag``.""" + +from dataclasses import FrozenInstanceError + +import pytest + +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + + +class TestRoleTag: + def test_role_tag_is_str_enum(self): + """``RoleTag`` is a ``(str, Enum)`` so values are JSON-serializable.""" + assert isinstance(RoleTag.SCALAR.value, str) + assert RoleTag.SCALAR == "scalar" + + def test_role_tag_members_complete(self): + """The five canonical tags are present and distinct.""" + names = {tag.name for tag in RoleTag} + assert names == {"SCALAR", "CHOICE", "REGISTRY_REF", "FACTORY", "OPAQUE"} + + +class TestRoleDescriptorConstruction: + def test_minimal_scalar_role(self): + role = RoleDescriptor(name="weakness_label", description="Label", tag=RoleTag.SCALAR) + assert role.name == "weakness_label" + assert role.tag is RoleTag.SCALAR + assert role.required is True + assert role.default is None + assert role.choices is None + + def test_frozen_instance(self): + role = RoleDescriptor(name="x", description="d", tag=RoleTag.SCALAR) + with pytest.raises(FrozenInstanceError): + role.name = "y" # type: ignore[misc] + + def test_choice_role_with_choices(self): + role = RoleDescriptor( + name="mode", + description="Operating mode", + tag=RoleTag.CHOICE, + choices=("fast", "slow"), + param_type=str, + ) + assert role.choices == ("fast", "slow") + + def test_choices_normalized_to_tuple(self): + role = RoleDescriptor( + name="mode", + description="d", + tag=RoleTag.CHOICE, + choices=["a", "b"], # type: ignore[arg-type] + ) + assert role.choices == ("a", "b") + assert isinstance(role.choices, tuple) + + +class TestRoleDescriptorValidation: + def test_name_must_be_identifier(self): + with pytest.raises(ValueError, match="valid Python identifier"): + RoleDescriptor(name="not-a-name", description="d", tag=RoleTag.SCALAR) + + def test_name_with_space_rejected(self): + with pytest.raises(ValueError, match="valid Python identifier"): + RoleDescriptor(name="bad name", description="d", tag=RoleTag.SCALAR) + + def test_choice_without_choices_rejected(self): + with pytest.raises(ValueError, match="must declare non-empty choices"): + RoleDescriptor(name="mode", description="d", tag=RoleTag.CHOICE) + + def test_choice_with_empty_choices_rejected(self): + with pytest.raises(ValueError, match="must declare non-empty choices"): + RoleDescriptor(name="mode", description="d", tag=RoleTag.CHOICE, choices=()) + + def test_non_choice_with_choices_rejected(self): + with pytest.raises(ValueError, match="must not declare choices"): + RoleDescriptor(name="x", description="d", tag=RoleTag.SCALAR, choices=("a",)) + + def test_opaque_optional_without_default_rejected(self): + with pytest.raises(ValueError, match="opaque roles cannot be elicited"): + RoleDescriptor(name="sweep", description="d", tag=RoleTag.OPAQUE, required=False) + + def test_opaque_optional_with_default_allowed(self): + sentinel = object() + role = RoleDescriptor( + name="sweep", + description="d", + tag=RoleTag.OPAQUE, + required=False, + default=sentinel, + ) + assert role.default is sentinel + + def test_opaque_required_allowed(self): + role = RoleDescriptor(name="sweep", description="d", tag=RoleTag.OPAQUE, required=True) + assert role.required is True + assert role.default is None diff --git a/tests/unit/scenario/test_waterfall.py b/tests/unit/scenario/test_waterfall.py new file mode 100644 index 000000000..425826660 --- /dev/null +++ b/tests/unit/scenario/test_waterfall.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Phase 8a — coverage for the forward waterfall in ``pyrit.scenario.core.waterfall``. + +Pins: +- ``policy_to_spec`` extracts ``AttackTechniqueSpec`` instances from a + configured scenario via the factory's ``source_spec`` attribute. +- ``spec_to_enum`` resolves specs to ``ScenarioStrategy`` members of a + scenario class, returning ``None`` when any spec doesn't match. +- Both functions degrade gracefully (return ``[]`` or ``None``) for + scenarios that don't use the technique registry pattern. +""" + +from __future__ import annotations + +from typing import ClassVar, cast +from unittest.mock import MagicMock + +import pytest + +from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import ComponentIdentifier +from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueRegistry, + AttackTechniqueSpec, +) +from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario +from pyrit.scenario.core.scenario_strategy import ScenarioStrategy +from pyrit.scenario.core.waterfall import policy_to_spec, spec_to_enum +from pyrit.score import Scorer + +# ---------- helpers ---------------------------------------------------------- + + +class _DummyStrategy(ScenarioStrategy): + ALPHA = ("alpha", set()) + BETA = ("beta", set()) + ALL = ("all", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + return {"all"} + + +def _make_scorer_mock() -> Scorer: + mock_scorer = MagicMock(spec=Scorer) + mock_scorer.get_identifier.return_value = ComponentIdentifier( + class_name="MockScorer", class_module="tests.unit.scenario" + ) + mock_scorer.get_scorer_metrics.return_value = None + return mock_scorer + + +class _DummyScenario(Scenario): + """Bare Scenario subclass; tests poke ``_scenario_strategies`` directly.""" + + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden + + def __init__(self, factories: dict[str, AttackTechniqueFactory] | None = None, **kwargs): + kwargs.setdefault("strategy_class", _DummyStrategy) + kwargs.setdefault("objective_scorer", _make_scorer_mock()) + super().__init__(**kwargs) + self._factories_override = factories or {} + self._scenario_strategies = [] + + @classmethod + def get_strategy_class(cls) -> type[ScenarioStrategy]: + return _DummyStrategy + + @classmethod + def get_default_strategy(cls) -> ScenarioStrategy: + return _DummyStrategy.ALL + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + return DatasetConfiguration() + + def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: + return self._factories_override + + async def _get_atomic_attacks_async(self): + return [] + + +def _spec(name: str) -> AttackTechniqueSpec: + return AttackTechniqueSpec(name=name, attack_class=PromptSendingAttack) + + +def _factory_from_spec(spec: AttackTechniqueSpec) -> AttackTechniqueFactory: + return AttackTechniqueRegistry.build_factory_from_spec(spec) + + +def _factory_without_spec() -> AttackTechniqueFactory: + return AttackTechniqueFactory(attack_class=PromptSendingAttack) + + +# ---------- policy_to_spec --------------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +class TestPolicyToSpec: + def test_returns_empty_for_unset_strategies(self): + scenario = _DummyScenario(name="t", version=1) + scenario._scenario_strategies = [] + assert policy_to_spec(scenario) == [] + + def test_returns_empty_when_get_factories_raises_not_implemented(self): + scenario = _DummyScenario(name="t", version=1) + scenario._scenario_strategies = [_DummyStrategy.ALPHA] + + def _no_factories() -> dict[str, AttackTechniqueFactory]: + raise NotImplementedError("policy-parameterized scenario") + + scenario._get_attack_technique_factories = _no_factories # type: ignore[assignment] + assert policy_to_spec(scenario) == [] + + def test_extracts_spec_for_selected_strategy(self): + spec = _spec("alpha") + scenario = _DummyScenario(name="t", version=1, factories={"alpha": _factory_from_spec(spec)}) + scenario._scenario_strategies = [_DummyStrategy.ALPHA] + result = policy_to_spec(scenario) + assert result == [spec] + + def test_preserves_strategy_order(self): + alpha, beta = _spec("alpha"), _spec("beta") + scenario = _DummyScenario( + name="t", + version=1, + factories={"alpha": _factory_from_spec(alpha), "beta": _factory_from_spec(beta)}, + ) + scenario._scenario_strategies = [_DummyStrategy.BETA, _DummyStrategy.ALPHA] + result = policy_to_spec(scenario) + assert result == [beta, alpha] + + def test_skips_factory_without_source_spec(self): + alpha = _spec("alpha") + scenario = _DummyScenario( + name="t", + version=1, + factories={ + "alpha": _factory_from_spec(alpha), + "beta": _factory_without_spec(), + }, + ) + scenario._scenario_strategies = [_DummyStrategy.ALPHA, _DummyStrategy.BETA] + # Beta has no source spec → silently skipped + assert policy_to_spec(scenario) == [alpha] + + def test_skips_strategy_with_missing_factory(self): + alpha = _spec("alpha") + scenario = _DummyScenario(name="t", version=1, factories={"alpha": _factory_from_spec(alpha)}) + scenario._scenario_strategies = [_DummyStrategy.ALPHA, _DummyStrategy.BETA] + assert policy_to_spec(scenario) == [alpha] + + +# ---------- spec_to_enum ----------------------------------------------------- + + +class TestSpecToEnum: + def test_empty_specs_returns_empty(self): + assert spec_to_enum(_DummyScenario, []) == [] + + def test_resolves_single_spec(self): + result = spec_to_enum(_DummyScenario, [_spec("alpha")]) + assert result == [_DummyStrategy.ALPHA] + + def test_preserves_input_order(self): + result = spec_to_enum(_DummyScenario, [_spec("beta"), _spec("alpha")]) + assert result == [_DummyStrategy.BETA, _DummyStrategy.ALPHA] + + def test_unknown_spec_returns_none(self): + # 'gamma' is not a member of _DummyStrategy + result = spec_to_enum(_DummyScenario, [_spec("alpha"), _spec("gamma")]) + assert result is None + + def test_returns_none_when_strategy_class_not_implemented(self): + class _NoStrategyScenario(_DummyScenario): + @classmethod + def get_strategy_class(cls) -> type[ScenarioStrategy]: + raise NotImplementedError("policy-parameterized") + + cls = cast("type[Scenario]", _NoStrategyScenario) + assert spec_to_enum(cls, [_spec("alpha")]) is None + + +# ---------- factory.source_spec round-trip ----------------------------------- + + +class TestFactorySourceSpec: + def test_factory_built_from_spec_carries_source_spec(self): + spec = _spec("alpha") + factory = AttackTechniqueRegistry.build_factory_from_spec(spec) + assert factory.source_spec is spec + + def test_factory_constructed_directly_has_no_source_spec(self): + factory = AttackTechniqueFactory(attack_class=PromptSendingAttack) + assert factory.source_spec is None + + def test_explicit_source_spec_kwarg_is_honored(self): + spec = _spec("alpha") + factory = AttackTechniqueFactory(attack_class=PromptSendingAttack, source_spec=spec) + assert factory.source_spec is spec + + +@pytest.mark.usefixtures("patch_central_database") +@pytest.mark.parametrize("name", ["alpha", "beta"]) +def test_round_trip_policy_to_spec_to_enum(name: str): + """A scenario with one strategy round-trips back to that strategy.""" + spec = _spec(name) + scenario = _DummyScenario(name="t", version=1, factories={name: _factory_from_spec(spec)}) + scenario._scenario_strategies = [_DummyStrategy[name.upper()]] + specs = policy_to_spec(scenario) + enums = spec_to_enum(_DummyScenario, specs) + assert enums == [_DummyStrategy[name.upper()]] From 228ecad916bccf1650a87462fe897b398557fb40 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 15:01:58 -0700 Subject: [PATCH 18/42] declare BroadSweepThenDeepDive baseline policy as Forbidden Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py | 9 ++++++++- .../scenario/scenarios/airt/test_sweep_then_deep_dive.py | 8 ++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index 9a23cf1c4..515cba0ab 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -40,7 +40,7 @@ from pyrit.common import apply_defaults from pyrit.models import Message from pyrit.scenario.core.dataset_configuration import DatasetConfiguration -from pyrit.scenario.core.scenario import Scenario +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.scenario.core.strategy_graph import ( @@ -435,6 +435,13 @@ class BroadSweepThenDeepDive(Scenario): VERSION: int = 1 + #: The branching graph in ``_build_execution_graph`` ignores the orchestrator-supplied + #: ``steps`` list and dispatches a fixed sweep -> deep-dive policy, so an implicitly + #: prepended baseline ``AtomicAttack`` would be persisted in ``_atomic_attacks`` but + #: never executed by the graph. ``Forbidden`` makes the (correct) intent explicit and + #: fails fast if a caller passes ``include_baseline=True``. + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden + @apply_defaults def __init__( self, diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py index 8e3ec23e6..becd8ee9c 100644 --- a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py @@ -30,6 +30,7 @@ from pyrit.models import AttackOutcome, AttackResult, MessagePiece from pyrit.scenario import DatasetConfiguration from pyrit.scenario.core import AtomicAttack +from pyrit.scenario.core.scenario import BaselineAttackPolicy from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( @@ -508,6 +509,13 @@ def test_strategy_metadata(self) -> None: assert BroadSweepThenDeepDive.get_default_strategy() is BroadSweepThenDeepDiveStrategy.DEFAULT assert isinstance(BroadSweepThenDeepDive.default_dataset_config(), DatasetConfiguration) + def test_baseline_attack_policy_is_forbidden(self) -> None: + # The branching graph ignores the orchestrator-supplied ``steps`` list, so an + # implicitly prepended baseline ``AtomicAttack`` would be persisted but never + # executed. ``Forbidden`` makes the (correct) intent explicit and fails fast + # if a caller passes ``include_baseline=True``. + assert BroadSweepThenDeepDive.BASELINE_ATTACK_POLICY is BaselineAttackPolicy.Forbidden + async def test_get_atomic_attacks_returns_canonical_order(self) -> None: scenario, sweep_atomic, deep_dives = self._build_scenario( sweep_response_text="safe", From 1bd09dd8a6eb1802c45046bff7ce36d735465946 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:11:40 -0700 Subject: [PATCH 19/42] FEAT: Phase 8b per-scenario input schemas + 8g prerequisites --- .../scenarios/adaptive/adaptive_scenario.py | 71 ++++++++- .../scenarios/airt/sweep_then_deep_dive.py | 67 ++++++++ .../adaptive/test_adaptive_input_schema.py | 146 ++++++++++++++++++ .../scenarios/adaptive/test_text_adaptive.py | 33 ++++ .../test_sweep_then_deep_dive_input_schema.py | 86 +++++++++++ 5 files changed, 399 insertions(+), 4 deletions(-) create mode 100644 tests/unit/scenario/scenarios/adaptive/test_adaptive_input_schema.py create mode 100644 tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index b9b2bf87e..cc8729153 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -16,12 +16,13 @@ from __future__ import annotations +import hashlib import logging import random -import uuid from typing import TYPE_CHECKING, ClassVar, cast from pyrit.executor.attack import AttackScoringConfig +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy @@ -110,6 +111,63 @@ def __init__( scenario_result_id=scenario_result_id, ) + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + """ + Declare the wizard-elicitable scalar inputs to ``__init__``. + + Only the four numeric / seed scalars are returned here. ``objective_scorer`` + and ``scenario_result_id`` are base-scenario lifecycle arguments handled by + the wizard's standard plumbing; ``context_extractor`` is a ``Callable`` + with a usable default (``global_context``) that the CLI elicitor cannot + introspect — programmatic callers may override it directly via + ``build_scenario_from_inputs(..., init_inputs={"context_extractor": ...})``. + Strategy selection (``scenario_strategies``) lives in + :meth:`supported_parameters` as an :meth:`initialize_async` ``--kebab-flag`` + argument, not in :meth:`__init__`. + + Returns: + list[RoleDescriptor]: Four SCALAR roles mirroring the constructor + defaults. + """ + return [ + RoleDescriptor( + name="epsilon", + description="Exploration probability for the epsilon-greedy selector (0.0 = pure exploit).", + tag=RoleTag.SCALAR, + param_type=float, + default=0.2, + required=False, + ), + RoleDescriptor( + name="pool_threshold", + description=( + "Minimum per-(context, technique) attempts before the local estimate overrides the pooled rate. " + "Set to 1 to disable pooling." + ), + tag=RoleTag.SCALAR, + param_type=int, + default=3, + required=False, + ), + RoleDescriptor( + name="max_attempts_per_objective", + description="Maximum number of techniques to dispatch per objective before giving up.", + tag=RoleTag.SCALAR, + param_type=int, + default=3, + required=False, + ), + RoleDescriptor( + name="seed", + description="RNG seed for deterministic technique selection. ``None`` uses a non-deterministic RNG.", + tag=RoleTag.SCALAR, + param_type=int, + default=None, + required=False, + ), + ] + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ Build one :class:`AdaptiveStep` per objective. @@ -260,9 +318,14 @@ def _build_step_for_seed_group( return None adaptive_context = self._context_extractor(seed_group) - # Prefer the objective's id when available so resume keys stay stable - # across re-fetches of the same seed groups. - objective_id = seed_group.objective.id if seed_group.objective.id else uuid.uuid4() + # Derive a deterministic 12-char hash from the objective text so two + # ``initialize_async`` runs over structurally identical seed groups + # produce identical step identifiers — a Phase 8g prerequisite for + # graph-artifact round-trip hash equivalence. ``SeedObjective.id`` + # has a ``uuid.uuid4()`` default-factory (``seed.py:95``) that mints + # a fresh UUID per in-memory construction, so it cannot be used as + # a stable resume key here. + objective_id = hashlib.sha256(seed_group.objective.value.encode("utf-8")).hexdigest()[:12] atomic_attack_name = f"{self._atomic_attack_prefix}_{dataset_name}_{objective_id}" memory_labels = { diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index 515cba0ab..ee463e5d6 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -40,6 +40,7 @@ from pyrit.common import apply_defaults from pyrit.models import Message from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult from pyrit.scenario.core.scenario_strategy import ScenarioStrategy @@ -119,6 +120,12 @@ class CategoryAggregatingSweepStep(ScenarioStep): _OUTPUTS: ClassVar[tuple[str, ...]] = ("found_weaknesses", "all_safe") + #: Marker for ``graph_artifact.build_graph_artifact`` (Phase 8g): this step's + #: constructor takes a bound ``OutcomeScorer`` and an ``AtomicAttack`` whose + #: factory closures cannot be re-derived from primitive args. Encoding via + #: ``ComponentIdentifier.to_dict()`` is the only sound round-trip path. + GRAPH_ARTIFACT_OPAQUE: ClassVar[bool] = True + def __init__( self, *, @@ -292,6 +299,12 @@ class FilteredDeepDiveStep(ScenarioStep): _OUTPUTS: ClassVar[tuple[str, ...]] = ("done",) + #: Marker for ``graph_artifact.build_graph_artifact`` (Phase 8g): this step's + #: constructor takes a ``Callable`` closure (``weak_categories_ref``) that + #: cannot be re-derived from primitive args. Encoding via + #: ``ComponentIdentifier.to_dict()`` is the only sound round-trip path. + GRAPH_ARTIFACT_OPAQUE: ClassVar[bool] = True + def __init__( self, *, @@ -502,6 +515,60 @@ def __init__( scenario_result_id=scenario_result_id, ) + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + """ + Declare the rich-object and scalar inputs the wizard / artifact must capture. + + The three opaque roles are pre-built ``Identifiable`` instances that the + CLI wizard cannot elicit directly — programmatic callers must supply them + and CLI flows must round-trip them through a saved graph artifact (see + :class:`pyrit.scenario.core.graph_artifact.GraphArtifact`). The single + scalar role (``weakness_label``) is freely elicitable. + + Returns: + list[RoleDescriptor]: Three OPAQUE roles plus one SCALAR. + """ + return [ + RoleDescriptor( + name="sweep_atomic_attack", + description=( + "Pre-built single-turn AtomicAttack run across all seed groups during the sweep phase. " + "Must already be wired to a target and scorer." + ), + tag=RoleTag.OPAQUE, + required=True, + ), + RoleDescriptor( + name="deep_dive_atomic_attacks", + description=( + "Pre-built sequence of multi-turn AtomicAttacks considered during the deep-dive phase. " + "Each is gated by its ``display_group`` matching a category flagged by the sweep." + ), + tag=RoleTag.OPAQUE, + required=True, + ), + RoleDescriptor( + name="outcome_scorer", + description=( + "Pre-built OutcomeScorer whose per-response label set must include ``weakness_label``. " + "Drives the sweep's category-weakness classification." + ), + tag=RoleTag.OPAQUE, + required=True, + ), + RoleDescriptor( + name="weakness_label", + description=( + "OutcomeScorer label that signals a category breach and triggers escalation to deep dive." + ), + tag=RoleTag.SCALAR, + param_type=str, + default="safety_violation", + required=False, + ), + ] + @classmethod def get_strategy_class(cls) -> type[ScenarioStrategy]: """Return the (single-member) strategy enum class.""" diff --git a/tests/unit/scenario/scenarios/adaptive/test_adaptive_input_schema.py b/tests/unit/scenario/scenarios/adaptive/test_adaptive_input_schema.py new file mode 100644 index 000000000..d1419c95c --- /dev/null +++ b/tests/unit/scenario/scenarios/adaptive/test_adaptive_input_schema.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Phase 8b — coverage for ``AdaptiveScenario.input_schema()`` + UUID determinism.""" + +from __future__ import annotations + +import hashlib + +import pytest + +from pyrit.models import SeedAttackGroup, SeedObjective +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag +from pyrit.scenario.scenarios.adaptive.adaptive_scenario import AdaptiveScenario + + +def _make_seed_group(*, value: str, harm_categories: list[str] | None = None) -> SeedAttackGroup: + return SeedAttackGroup(seeds=[SeedObjective(value=value, harm_categories=harm_categories)]) + + +class TestAdaptiveInputSchema: + """``AdaptiveScenario.input_schema()`` declares the wizard-elicitable scalars.""" + + def test_returns_four_roles(self): + schema = AdaptiveScenario.input_schema() + assert len(schema) == 4 + assert all(isinstance(role, RoleDescriptor) for role in schema) + + def test_role_names_match_constructor_scalars(self): + names = [role.name for role in AdaptiveScenario.input_schema()] + assert names == ["epsilon", "pool_threshold", "max_attempts_per_objective", "seed"] + + def test_all_roles_are_scalar(self): + schema = AdaptiveScenario.input_schema() + assert all(role.tag is RoleTag.SCALAR for role in schema) + + def test_defaults_mirror_constructor_signature(self): + by_name = {role.name: role for role in AdaptiveScenario.input_schema()} + assert by_name["epsilon"].default == 0.2 + assert by_name["pool_threshold"].default == 3 + assert by_name["max_attempts_per_objective"].default == 3 + assert by_name["seed"].default is None + + def test_param_types_match_constructor_annotations(self): + by_name = {role.name: role for role in AdaptiveScenario.input_schema()} + assert by_name["epsilon"].param_type is float + assert by_name["pool_threshold"].param_type is int + assert by_name["max_attempts_per_objective"].param_type is int + assert by_name["seed"].param_type is int + + def test_all_roles_are_optional(self): + """Every constructor scalar has a default; none are wizard-required.""" + schema = AdaptiveScenario.input_schema() + assert all(role.required is False for role in schema) + + def test_no_role_for_objective_scorer(self): + """``objective_scorer`` is base-lifecycle plumbing, handled outside input_schema().""" + names = {role.name for role in AdaptiveScenario.input_schema()} + assert "objective_scorer" not in names + + def test_no_role_for_context_extractor(self): + """``context_extractor`` is a Callable; CLI cannot elicit it. Programmatic-only override.""" + names = {role.name for role in AdaptiveScenario.input_schema()} + assert "context_extractor" not in names + + def test_no_role_for_scenario_strategies(self): + """``scenario_strategies`` lives in ``supported_parameters()``, not input_schema().""" + names = {role.name for role in AdaptiveScenario.input_schema()} + assert "scenario_strategies" not in names + + +class TestObjectiveIdIsNonDeterministic: + """8b-1 motivating invariant: ``SeedObjective.id`` auto-mints fresh UUIDs. + + This is why Phase 8b switched to content-hashing instead of trusting + ``seed_group.objective.id`` for the atomic_attack_name suffix. + """ + + def test_seed_objective_id_default_is_fresh_uuid_per_construction(self): + """``SeedObjective(value=...)`` mints a new UUID each call (see ``Seed.id`` default_factory). + + This is the root cause of the non-determinism that Phase 8b fixed: + relying on ``objective.id`` for resume keys breaks across in-memory re-construction. + """ + a = SeedObjective(value="cause harm") + b = SeedObjective(value="cause harm") + # Both have non-falsy ids (UUID4), but those ids are different. + assert a.id + assert b.id + assert a.id != b.id + + def test_sha256_content_hash_is_stable_across_constructions(self): + """The Phase 8b replacement (content hash) is deterministic by construction.""" + a = SeedObjective(value="cause harm") + b = SeedObjective(value="cause harm") + ha = hashlib.sha256(a.value.encode("utf-8")).hexdigest()[:12] + hb = hashlib.sha256(b.value.encode("utf-8")).hexdigest()[:12] + assert ha == hb + assert len(ha) == 12 + + def test_sha256_content_hash_differs_for_different_inputs(self): + h1 = hashlib.sha256(b"obj-1").hexdigest()[:12] + h2 = hashlib.sha256(b"obj-2").hexdigest()[:12] + assert h1 != h2 + + def test_uuid_module_no_longer_imported_in_adaptive_scenario(self): + """``uuid`` is no longer imported by adaptive_scenario after Phase 8b. + + Phase 8b removed the ``uuid.uuid4()`` fallback and the + ``seed_group.objective.id`` preference in favor of a deterministic + SHA256 hash of the objective text. Re-adding ``uuid`` would re-introduce + the non-determinism — this test guards against that regression by checking + the module namespace directly. + """ + import pyrit.scenario.scenarios.adaptive.adaptive_scenario as mod + + assert "uuid" not in mod.__dict__, ( + "uuid import was removed in Phase 8b in favor of deterministic SHA256 hashing; " + "if you re-added it for a new use, also re-evaluate the deterministic-step-name invariant." + ) + + +@pytest.fixture +def two_identical_seed_groups() -> tuple[SeedAttackGroup, SeedAttackGroup]: + """Two seed groups with identical objective text. Each gets its own ``objective.id`` UUID.""" + return ( + _make_seed_group(value="obj-deterministic", harm_categories=["violence"]), + _make_seed_group(value="obj-deterministic", harm_categories=["violence"]), + ) + + +class TestContentHashRegressionSurface: + """Pin the documented intent of the 8b-1 source change at a behavioral level. + + Note: a full ``_build_step_for_seed_group`` integration regression lives in + ``test_text_adaptive.py::TestTextAdaptiveAtomicAttacks::test_atomic_names_are_deterministic_across_runs``. + This class pins the lower-level invariants the integration test depends on. + """ + + def test_same_objective_value_yields_same_hash(self, two_identical_seed_groups): + a, b = two_identical_seed_groups + # Their auto-assigned ids differ (UUID4) — the hash is what makes resume stable. + assert a.objective.id != b.objective.id + hash_a = hashlib.sha256(a.objective.value.encode("utf-8")).hexdigest()[:12] + hash_b = hashlib.sha256(b.objective.value.encode("utf-8")).hexdigest()[:12] + assert hash_a == hash_b diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index bea5dba6c..5d97b2817 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -251,6 +251,39 @@ async def test_atomic_names_are_unique(self, mock_objective_target, mock_objecti names = [atomic.atomic_attack_name for atomic in attacks] assert len(set(names)) == len(names) + async def test_atomic_names_are_deterministic_across_runs(self, mock_objective_target, mock_objective_scorer): + """Phase 8b-1 regression: SHA256 fallback for unset objective.id is deterministic. + + Building the scenario twice with structurally identical seed groups must + produce identical atomic_attack_names. With the previous ``uuid.uuid4()`` + fallback, the two runs would produce different names and graph-artifact + round-trip (Phase 8g) would fail its hash-equivalence invariant. + """ + groups_factory = lambda: { # noqa: E731 + "violence": [ + _make_seed_group(value="obj-determ-1", harm_categories=["violence"]), + _make_seed_group(value="obj-determ-2", harm_categories=["violence"]), + ], + "hate": [_make_seed_group(value="obj-determ-3", harm_categories=["hate"])], + } + _s1, attacks_first = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups_factory(), + ) + _s2, attacks_second = await self._build_scenario_and_attacks( + mock_objective_target=mock_objective_target, + mock_objective_scorer=mock_objective_scorer, + seed_groups=groups_factory(), + ) + names_first = sorted(atomic.atomic_attack_name for atomic in attacks_first) + names_second = sorted(atomic.atomic_attack_name for atomic in attacks_second) + assert names_first == names_second, ( + "Atomic-attack names must be deterministic across runs with structurally " + "identical seed groups. The Phase 8b SHA256 fallback for unset objective.id " + "was likely replaced with a non-deterministic primitive." + ) + async def test_display_group_is_dataset_name(self, mock_objective_target, mock_objective_scorer): groups = { "violence": [_make_seed_group(value="obj-v", harm_categories=["violence"])], diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py new file mode 100644 index 000000000..39da38ebd --- /dev/null +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Phase 8b — coverage for ``BroadSweepThenDeepDive.input_schema()`` + opacity flags.""" + +from __future__ import annotations + +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag +from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + CategoryAggregatingSweepStep, + FilteredDeepDiveStep, +) + + +class TestBroadSweepThenDeepDiveInputSchema: + """``BroadSweepThenDeepDive.input_schema()`` declares 3 OPAQUE roles + 1 SCALAR.""" + + def test_returns_four_roles(self): + schema = BroadSweepThenDeepDive.input_schema() + assert len(schema) == 4 + assert all(isinstance(role, RoleDescriptor) for role in schema) + + def test_role_names_match_constructor_inputs(self): + names = [role.name for role in BroadSweepThenDeepDive.input_schema()] + assert names == [ + "sweep_atomic_attack", + "deep_dive_atomic_attacks", + "outcome_scorer", + "weakness_label", + ] + + def test_three_opaque_one_scalar(self): + by_tag: dict[RoleTag, list[str]] = {tag: [] for tag in RoleTag} + for role in BroadSweepThenDeepDive.input_schema(): + by_tag[role.tag].append(role.name) + assert sorted(by_tag[RoleTag.OPAQUE]) == sorted( + ["sweep_atomic_attack", "deep_dive_atomic_attacks", "outcome_scorer"] + ) + assert by_tag[RoleTag.SCALAR] == ["weakness_label"] + # No other tags present. + for tag, names in by_tag.items(): + if tag not in {RoleTag.OPAQUE, RoleTag.SCALAR}: + assert names == [] + + def test_opaque_roles_are_required(self): + """Each pre-built ``Identifiable`` instance is non-optional.""" + opaque_roles = [role for role in BroadSweepThenDeepDive.input_schema() if role.tag is RoleTag.OPAQUE] + assert len(opaque_roles) == 3 + assert all(role.required is True for role in opaque_roles) + + def test_weakness_label_defaults_match_constructor(self): + by_name = {role.name: role for role in BroadSweepThenDeepDive.input_schema()} + weakness = by_name["weakness_label"] + assert weakness.default == "safety_violation" + assert weakness.required is False + assert weakness.param_type is str + + def test_no_role_for_objective_scorer(self): + """``objective_scorer`` is base-lifecycle plumbing, handled outside input_schema().""" + names = {role.name for role in BroadSweepThenDeepDive.input_schema()} + assert "objective_scorer" not in names + + def test_no_role_for_scenario_result_id(self): + """``scenario_result_id`` is resume plumbing, never a wizard role.""" + names = {role.name for role in BroadSweepThenDeepDive.input_schema()} + assert "scenario_result_id" not in names + + def test_descriptions_are_non_empty(self): + for role in BroadSweepThenDeepDive.input_schema(): + assert role.description, f"role {role.name!r} has empty description" + + +class TestGraphArtifactOpaque: + """8b-2 invariant: both branching-step classes flag themselves as artifact-opaque.""" + + def test_category_aggregating_sweep_step_is_opaque(self): + assert CategoryAggregatingSweepStep.GRAPH_ARTIFACT_OPAQUE is True + + def test_filtered_deep_dive_step_is_opaque(self): + assert FilteredDeepDiveStep.GRAPH_ARTIFACT_OPAQUE is True + + def test_opacity_is_class_level_not_instance_level(self): + """The flag must be readable from the class without instantiation.""" + assert hasattr(CategoryAggregatingSweepStep, "GRAPH_ARTIFACT_OPAQUE") + assert hasattr(FilteredDeepDiveStep, "GRAPH_ARTIFACT_OPAQUE") From 4c8dd94ed1fb1f7c0b9e91c061e882bd5c9267ec Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:21:07 -0700 Subject: [PATCH 20/42] FEAT: Phase 8c scenario builder + input collectors --- pyrit/scenario/__init__.py | 28 +- pyrit/scenario/core/__init__.py | 34 +- pyrit/scenario/core/builder.py | 169 +++++++++ pyrit/scenario/core/input_collector.py | 352 ++++++++++++++++++ tests/unit/scenario/core/__init__.py | 0 tests/unit/scenario/core/test_builder.py | 251 +++++++++++++ .../scenario/core/test_input_collector.py | 319 ++++++++++++++++ 7 files changed, 1148 insertions(+), 5 deletions(-) create mode 100644 pyrit/scenario/core/builder.py create mode 100644 pyrit/scenario/core/input_collector.py create mode 100644 tests/unit/scenario/core/__init__.py create mode 100644 tests/unit/scenario/core/test_builder.py create mode 100644 tests/unit/scenario/core/test_input_collector.py diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index cb113712f..be35b71d7 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -18,26 +18,38 @@ from pyrit.common.parameter import Parameter from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.scenario.core import ( + ArtifactInputCollector, AtomicAttack, AttackTechnique, AttackTechniqueFactory, BaselineAttackPolicy, + CliInputCollector, DatasetConfiguration, + DictInputCollector, + InputCollector, + MaxAttemptsExceededError, + OpaqueRoleNotElicitableError, PolicyAction, RoleDescriptor, RoleTag, Scenario, ScenarioCompositeStrategy, ScenarioCoreState, + ScenarioInputValidationError, ScenarioStateLike, ScenarioStep, ScenarioStepResult, ScenarioStrategy, StrategyGraph, StrategyPolicy, + build_scenario_from_inputs, + collect_inputs_with_retry, + discover_input_schema, + discover_supported_parameters, linear_strategy_policy, policy_to_spec, spec_to_enum, + validate_init_inputs, ) # Import scenario submodules directly and register them as virtual subpackages @@ -63,11 +75,17 @@ foundry = _foundry_module __all__ = [ + "ArtifactInputCollector", "AtomicAttack", "AttackTechnique", "AttackTechniqueFactory", "BaselineAttackPolicy", + "CliInputCollector", "DatasetConfiguration", + "DictInputCollector", + "InputCollector", + "MaxAttemptsExceededError", + "OpaqueRoleNotElicitableError", "Parameter", "PolicyAction", "RoleDescriptor", @@ -75,20 +93,26 @@ "Scenario", "ScenarioCompositeStrategy", "ScenarioCoreState", + "ScenarioIdentifier", + "ScenarioInputValidationError", + "ScenarioResult", "ScenarioStateLike", "ScenarioStep", "ScenarioStepResult", "ScenarioStrategy", - "ScenarioIdentifier", - "ScenarioResult", "StrategyGraph", "StrategyPolicy", "adaptive", "airt", "benchmark", + "build_scenario_from_inputs", + "collect_inputs_with_retry", + "discover_input_schema", + "discover_supported_parameters", "foundry", "garak", "linear_strategy_policy", "policy_to_spec", "spec_to_enum", + "validate_init_inputs", ] diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 063d5508c..d704f620b 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -7,7 +7,23 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy +from pyrit.scenario.core.builder import ( + ScenarioInputValidationError, + build_scenario_from_inputs, + discover_input_schema, + discover_supported_parameters, + validate_init_inputs, +) from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration +from pyrit.scenario.core.input_collector import ( + ArtifactInputCollector, + CliInputCollector, + DictInputCollector, + InputCollector, + MaxAttemptsExceededError, + OpaqueRoleNotElicitableError, + collect_inputs_with_retry, +) from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike @@ -22,20 +38,27 @@ from pyrit.scenario.core.waterfall import policy_to_spec, spec_to_enum __all__ = [ + "ArtifactInputCollector", "AtomicAttack", "AttackTechnique", "AttackTechniqueFactory", "BaselineAttackPolicy", + "CliInputCollector", "DatasetConfiguration", + "DictInputCollector", "EXPLICIT_SEED_GROUPS_KEY", + "InputCollector", + "MaxAttemptsExceededError", + "OpaqueRoleNotElicitableError", + "Parameter", "PolicyAction", "RoleDescriptor", "RoleTag", "SCENARIO_TECHNIQUES", - "Parameter", "Scenario", "ScenarioCompositeStrategy", "ScenarioCoreState", + "ScenarioInputValidationError", "ScenarioStateLike", "ScenarioStep", "ScenarioStepResult", @@ -43,10 +66,15 @@ "ScorerOverridePolicy", "StrategyGraph", "StrategyPolicy", + "build_scenario_from_inputs", + "collect_inputs_with_retry", + "discover_input_schema", + "discover_supported_parameters", + "get_default_adversarial_target", + "get_default_scorer_target", "linear_strategy_policy", "policy_to_spec", "register_scenario_techniques", - "get_default_scorer_target", - "get_default_adversarial_target", "spec_to_enum", + "validate_init_inputs", ] diff --git a/pyrit/scenario/core/builder.py b/pyrit/scenario/core/builder.py new file mode 100644 index 000000000..28fa960ba --- /dev/null +++ b/pyrit/scenario/core/builder.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario builder — module functions for constructing scenarios from declared inputs. + +The builder is the single entry point used by the wizard CLI (Phase 8d), the +artifact loader (Phase 8g), and the inverse-waterfall ``pyrit_scan --from-artifact`` +path (Phase 8e). It validates declared inputs against the scenario's +:meth:`Scenario.input_schema` (rich-object ``__init__`` roles) and +:meth:`Scenario.supported_parameters` (scalar ``initialize_async`` args), then +runs both lifecycle phases and returns a runnable scenario. + +The retry-on-ValidationError loop is intentionally factored out into +:mod:`pyrit.scenario.core.input_collector` so the builder itself stays +collector-agnostic; the CLI driver wires the two together. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + +if TYPE_CHECKING: + from pyrit.common.parameter import Parameter + from pyrit.scenario.core.scenario import Scenario + + +class ScenarioInputValidationError(ValueError): + """ + Raised when supplied init inputs fail validation against ``input_schema()``. + + Carries the role name (when applicable) so the collector retry loop can + re-prompt the specific role rather than restarting collection from scratch. + """ + + def __init__(self, message: str, *, role_name: str | None = None) -> None: + """Store ``message`` and remember which role failed (when known).""" + super().__init__(message) + self.role_name = role_name + + +def discover_input_schema(scenario_cls: type[Scenario]) -> list[RoleDescriptor]: + """ + Return the declared rich-object ``__init__`` roles for the scenario. + + Thin wrapper around :meth:`Scenario.input_schema` so the wizard layer can + treat schema discovery uniformly across all scenarios — including base-class + scenarios that take the default ``[]`` from :class:`Scenario`. + + Args: + scenario_cls: Concrete subclass of :class:`Scenario`. + + Returns: + list[RoleDescriptor]: The role list returned by the class. + """ + return list(scenario_cls.input_schema()) + + +def discover_supported_parameters(scenario_cls: type[Scenario]) -> list[Parameter]: + """ + Return the declared scalar ``initialize_async`` parameters for the scenario. + + Thin wrapper around :meth:`Scenario.supported_parameters` so the wizard layer + can treat both halves of the lifecycle symmetrically. + + Args: + scenario_cls: Concrete subclass of :class:`Scenario`. + + Returns: + list[Parameter]: The parameter list returned by the class. + """ + return list(scenario_cls.supported_parameters()) + + +def validate_init_inputs( + *, + schema: list[RoleDescriptor], + init_inputs: dict[str, Any], +) -> None: + """ + Validate ``init_inputs`` against the scenario's declared ``input_schema``. + + The builder calls this before constructing the scenario so a malformed input + set surfaces as a structured :class:`ScenarioInputValidationError` (with the + offending role name) rather than a downstream constructor ``TypeError``. + + Checks performed: + + * Every required role is present in ``init_inputs``. + * No ``CHOICE`` role's value falls outside its declared choices. + + Type coercion is deliberately left to the collector layer: CLI collectors + coerce strings to typed primitives at elicitation time; programmatic + collectors pass typed values directly; opaque roles bypass coercion entirely. + + Args: + schema: The declared role list (typically from + :func:`discover_input_schema`). + init_inputs: Caller-supplied init inputs keyed by role name. + + Raises: + ScenarioInputValidationError: If a required role is missing or a + choice role's value is not in the declared choices. + """ + schema_by_name = {role.name: role for role in schema} + + for role in schema: + if role.required and role.name not in init_inputs: + raise ScenarioInputValidationError( + f"Missing required init input {role.name!r} ({role.tag.value}). Description: {role.description}", + role_name=role.name, + ) + + for name, value in init_inputs.items(): + role = schema_by_name.get(name) + if role is None: + # Unknown keys are not an error: scenarios may accept kwargs that aren't + # explicitly schema-declared (e.g. ``scenario_result_id``). Skip silently. + continue + + if role.tag is RoleTag.CHOICE: + assert role.choices is not None # __post_init__ guarantees this + if value not in role.choices: + raise ScenarioInputValidationError( + f"Invalid value for role {name!r}: {value!r} is not in declared choices {role.choices!r}.", + role_name=name, + ) + + +async def build_scenario_from_inputs( + scenario_cls: type[Scenario], + *, + init_inputs: dict[str, Any], + init_async_inputs: dict[str, Any], +) -> Scenario: + """ + Construct a scenario, run ``initialize_async``, and return the runnable instance. + + This is the single entry point used by the wizard CLI (Phase 8d) and the + artifact loader (Phase 8g). It validates ``init_inputs`` against the + scenario's ``input_schema`` first so a malformed input set surfaces as + :class:`ScenarioInputValidationError` (carrying the offending role name) + rather than a deep ``TypeError`` from the constructor. + + ``init_async_inputs`` are passed verbatim to ``initialize_async``. The + existing :meth:`Scenario.set_params_from_args` machinery already validates + these against ``supported_parameters``, so the builder does not re-validate. + + Args: + scenario_cls: Concrete subclass of :class:`Scenario`. + init_inputs: Rich-object ``__init__`` arguments keyed by role name. All + required roles from ``input_schema()`` must be present. + init_async_inputs: Scalar ``initialize_async`` arguments. Validated by + ``Scenario.set_params_from_args`` via ``supported_parameters()``. + + Returns: + Scenario: An initialized, runnable scenario instance. + + Raises: + ScenarioInputValidationError: If ``init_inputs`` fails validation. + """ + schema = discover_input_schema(scenario_cls) + validate_init_inputs(schema=schema, init_inputs=init_inputs) + + scenario = scenario_cls(**init_inputs) + await scenario.initialize_async(**init_async_inputs) + return scenario diff --git a/pyrit/scenario/core/input_collector.py b/pyrit/scenario/core/input_collector.py new file mode 100644 index 000000000..ad36625f1 --- /dev/null +++ b/pyrit/scenario/core/input_collector.py @@ -0,0 +1,352 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Input collectors — elicit ``init_inputs`` from various sources for the wizard. + +The :class:`InputCollector` Protocol is the front-end abstraction: a CLI +implementation reads from stdin, a programmatic one reads from a dict, and an +artifact-replay one reads from a previously saved +:class:`pyrit.scenario.core.graph_artifact.GraphArtifact`. The +:func:`collect_inputs_with_retry` helper drives the error-recovery loop with +the builder's :class:`pyrit.scenario.core.builder.ScenarioInputValidationError`. + +Opaque roles (pre-built ``Identifiable`` instances) cannot be elicited from a +pure CLI flow — the :class:`CliInputCollector` reference impl raises +:class:`OpaqueRoleNotElicitableError` with a pointer at ``--from-artifact`` +when it encounters one. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from pyrit.scenario.core.builder import ScenarioInputValidationError +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + +class OpaqueRoleNotElicitableError(NotImplementedError): + """Raised when a CLI collector encounters an OPAQUE role it cannot elicit.""" + + def __init__(self, role_name: str) -> None: + """Format a help message pointing at the artifact-replay workaround.""" + super().__init__( + f"Role {role_name!r} is OPAQUE (a pre-built Identifiable instance) and cannot " + "be elicited from a CLI prompt. Provide it programmatically via " + "``build_scenario_from_inputs(..., init_inputs={...})`` or replay it from a " + "saved graph artifact via ``pyrit_scan --from-artifact path.yaml``." + ) + self.role_name = role_name + + +class MaxAttemptsExceededError(RuntimeError): + """Raised when a collector loop exceeds the per-role retry budget.""" + + def __init__(self, role_name: str, attempts: int, last_error: Exception | None) -> None: + """Format a diagnostic message including the last seen validation error.""" + message = ( + f"Exceeded {attempts} elicitation attempts for role {role_name!r}. Last error: {last_error!r}" + if last_error + else f"Exceeded {attempts} elicitation attempts for role {role_name!r}." + ) + super().__init__(message) + self.role_name = role_name + self.attempts = attempts + self.last_error = last_error + + +@runtime_checkable +class InputCollector(Protocol): + """ + Protocol for eliciting a single role's value. + + Implementations may be sync (return immediately) or wrap async work; the + wizard driver awaits each ``collect`` call so both forms are supported. + The ``error`` and ``attempt`` arguments enable the retry loop in + :func:`collect_inputs_with_retry`: on a validation failure the driver + re-invokes ``collect`` with the prior exception and an incremented attempt + counter, so collectors can re-prompt with the error context visible to the + user. + """ + + def collect( + self, + *, + role: RoleDescriptor, + error: Exception | None = None, + attempt: int = 0, + ) -> Any: + """Return the value for ``role`` (see implementations for source-specific behavior).""" + ... + + +class DictInputCollector: + """ + Programmatic / test collector that returns values from a supplied dict. + + Mainly used in unit tests and by the wizard library API. Missing required + roles raise :class:`ScenarioInputValidationError` (consistent with the + builder's surface) so the test author sees the same error a CLI user would. + """ + + def __init__(self, values: Mapping[str, Any]) -> None: + """ + Initialize the collector. + + Args: + values: Mapping from role name to value. Roles absent from the + mapping return their declared default if optional, or raise + :class:`ScenarioInputValidationError` if required. + """ + self._values = dict(values) + + def collect( + self, + *, + role: RoleDescriptor, + error: Exception | None = None, + attempt: int = 0, + ) -> Any: + """ + Return the value for ``role`` from the supplied dict. + + Returns: + Any: The value bound to ``role.name`` in the dict, or ``role.default`` + if the role is optional and absent. + + Raises: + ScenarioInputValidationError: If the role is required and absent from + the dict. + """ + if role.name in self._values: + return self._values[role.name] + if role.required: + raise ScenarioInputValidationError( + f"DictInputCollector has no value for required role {role.name!r}.", + role_name=role.name, + ) + return role.default + + +class CliInputCollector: + """ + Stdin/stdout reference collector for interactive elicitation. + + Reads one role at a time via the standard :func:`input` builtin (overridable + via ``input_fn`` for testability). Coerces strings to the role's declared + ``param_type`` for ``SCALAR`` roles and validates membership for ``CHOICE`` + roles. Refuses to elicit ``OPAQUE`` roles (see + :class:`OpaqueRoleNotElicitableError`). + + Error-recovery contract: when invoked with a non-None ``error``, the + collector prefixes the re-prompt with the error message so the user sees + what went wrong before re-entering. + """ + + def __init__( + self, + *, + input_fn: Callable[[str], str] | None = None, + output_fn: Callable[[str], None] | None = None, + ) -> None: + """ + Initialize the CLI collector. + + Args: + input_fn: Function used to read a line from the user. Defaults to + the builtin :func:`input`. Override for testing. + output_fn: Function used to write a prompt line. Defaults to + :func:`print`. Override for testing. + """ + self._input = input_fn or input + self._print = output_fn or (lambda msg: print(msg)) + + def collect( + self, + *, + role: RoleDescriptor, + error: Exception | None = None, + attempt: int = 0, + ) -> Any: + """ + Prompt the user for the role's value and return it (coerced if applicable). + + Args: + role: The role to elicit. + error: Previous validation error, if any, surfaced to the user. + attempt: Current attempt count (0 = first prompt). + + Returns: + The collected (and coerced) value. + + Raises: + OpaqueRoleNotElicitableError: If ``role.tag is RoleTag.OPAQUE``. + ScenarioInputValidationError: If the user enters a value that + fails coercion or choice-membership validation. + """ + if role.tag is RoleTag.OPAQUE: + raise OpaqueRoleNotElicitableError(role.name) + + if error is not None: + self._print(f" ! Previous attempt failed: {error}") + + prompt = self._render_prompt(role=role, attempt=attempt) + raw = self._input(prompt).strip() + + if raw == "" and not role.required: + return role.default + + if role.tag is RoleTag.CHOICE: + return self._validate_choice(role=role, raw=raw) + + if role.tag is RoleTag.SCALAR and role.param_type is not None: + return self._coerce_scalar(role=role, raw=raw) + + # SCALAR with no param_type, REGISTRY_REF, FACTORY — return string as-is. + # The builder / downstream code is responsible for any deeper coercion. + return raw + + def _render_prompt(self, *, role: RoleDescriptor, attempt: int) -> str: + bits = [f"{role.name}"] + if role.description: + bits.append(f" ({role.description})") + if role.tag is RoleTag.CHOICE and role.choices is not None: + bits.append(f" [choices: {', '.join(map(str, role.choices))}]") + if not role.required and role.default is not None: + bits.append(f" [default: {role.default!r}]") + elif not role.required: + bits.append(" [optional]") + prefix = f"(attempt {attempt + 1}) " if attempt > 0 else "" + return f"{prefix}{''.join(bits)}: " + + def _validate_choice(self, *, role: RoleDescriptor, raw: str) -> Any: + assert role.choices is not None + # Compare raw string against str(choice); accept the original-typed choice on match. + for choice in role.choices: + if str(choice) == raw or choice == raw: + return choice + raise ScenarioInputValidationError( + f"Invalid choice for role {role.name!r}: {raw!r} not in {role.choices!r}.", + role_name=role.name, + ) + + def _coerce_scalar(self, *, role: RoleDescriptor, raw: str) -> Any: + target = role.param_type + try: + if target is bool: + if raw.lower() in {"true", "yes", "y", "1"}: + return True + if raw.lower() in {"false", "no", "n", "0"}: + return False + raise ValueError(f"Cannot interpret {raw!r} as bool.") + if target is int: + return int(raw) + if target is float: + return float(raw) + if target is str: + return raw + except (TypeError, ValueError) as exc: + raise ScenarioInputValidationError( + f"Failed to coerce {raw!r} to {target!r} for role {role.name!r}: {exc}", + role_name=role.name, + ) from exc + # Unrecognized param_type — pass through as string. + return raw + + +class ArtifactInputCollector: + """ + Collector that replays init_inputs from a previously saved graph artifact. + + Used by ``pyrit_scan --from-artifact`` (Phase 8e) and by the artifact + loader (Phase 8g). Looks each role up by name in the artifact's + ``init_inputs`` dict. Opaque roles surface as their stored + ``ComponentIdentifier.to_dict()`` payload — the load path is responsible + for re-materializing them into live instances. + """ + + def __init__(self, init_inputs: Mapping[str, Any]) -> None: + """ + Initialize the collector. + + Args: + init_inputs: The ``init_inputs`` map from a saved + :class:`pyrit.scenario.core.graph_artifact.GraphArtifact`. + """ + self._init_inputs = dict(init_inputs) + + def collect( + self, + *, + role: RoleDescriptor, + error: Exception | None = None, + attempt: int = 0, + ) -> Any: + """ + Return the recorded value for ``role`` from the artifact. + + Returns: + Any: The value recorded under ``role.name`` in the artifact, or + ``role.default`` if the role is optional and absent. + + Raises: + ScenarioInputValidationError: If the role is required and absent from + the artifact. + """ + if role.name in self._init_inputs: + return self._init_inputs[role.name] + if role.required: + raise ScenarioInputValidationError( + f"Artifact has no recorded value for required role {role.name!r}.", + role_name=role.name, + ) + return role.default + + +def collect_inputs_with_retry( + *, + collector: InputCollector, + schema: list[RoleDescriptor], + max_attempts: int = 5, +) -> dict[str, Any]: + """ + Drive an :class:`InputCollector` to gather a full init-input set. + + Iterates the schema in declared order. For each role, calls + :meth:`InputCollector.collect`. If the call raises + :class:`ScenarioInputValidationError`, the loop retries up to + ``max_attempts`` times with the prior error passed back to the collector + so it can re-prompt with context. After exhausting the budget, raises + :class:`MaxAttemptsExceededError`. + + Args: + collector: Any :class:`InputCollector` implementation. + schema: The schema to collect against (typically from + :func:`pyrit.scenario.core.builder.discover_input_schema`). + max_attempts: Maximum attempts per role before giving up. Defaults to 5. + + Returns: + dict[str, Any]: Collected values keyed by role name. Optional roles + absent from the collector are omitted (the scenario constructor + then uses its own default). + + Raises: + MaxAttemptsExceededError: When the per-role retry budget is exhausted. + """ + collected: dict[str, Any] = {} + for role in schema: + last_error: Exception | None = None + for attempt in range(max_attempts): + try: + value = collector.collect(role=role, error=last_error, attempt=attempt) + except ScenarioInputValidationError as exc: + last_error = exc + continue + collected[role.name] = value + break + else: + raise MaxAttemptsExceededError(role.name, max_attempts, last_error) + return collected diff --git a/tests/unit/scenario/core/__init__.py b/tests/unit/scenario/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/scenario/core/test_builder.py b/tests/unit/scenario/core/test_builder.py new file mode 100644 index 000000000..302e1948e --- /dev/null +++ b/tests/unit/scenario/core/test_builder.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Phase 8c — coverage for the scenario builder module functions.""" + +from __future__ import annotations + +from typing import Any, cast + +import pytest + +from pyrit.common.parameter import Parameter +from pyrit.scenario.core.builder import ( + ScenarioInputValidationError, + build_scenario_from_inputs, + discover_input_schema, + discover_supported_parameters, + validate_init_inputs, +) +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + + +class _FakeScenarioBase: + """Minimal duck-typed stand-in for ``Scenario`` to exercise the builder. + + The builder only touches ``input_schema`` and ``supported_parameters`` at + discovery time and ``__init__`` / ``initialize_async`` at build time, so we + deliberately avoid the real ``Scenario`` heavyweight constructor (memory, + identifier, deprecation machinery). + """ + + _schema: list[RoleDescriptor] = [] + _params: list[Parameter] = [] + + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return cls._schema + + @classmethod + def supported_parameters(cls) -> list[Parameter]: + return cls._params + + +class _FakeScenarioNoArgs(_FakeScenarioBase): + def __init__(self) -> None: + self.init_called = True + self.init_async_called = False + + async def initialize_async(self) -> None: + self.init_async_called = True + + +class _FakeScenarioScalarRoles(_FakeScenarioBase): + _schema = [ + RoleDescriptor(name="weakness_label", description="Label", tag=RoleTag.SCALAR, param_type=str), + RoleDescriptor( + name="threshold", + description="Score cutoff", + tag=RoleTag.SCALAR, + param_type=float, + default=0.5, + required=False, + ), + ] + _params = [ + Parameter(name="max_concurrency", description="Concurrency", default=1, param_type=int), + ] + + def __init__(self, *, weakness_label: str, threshold: float = 0.5) -> None: + self.weakness_label = weakness_label + self.threshold = threshold + self.init_async_max_concurrency: int | None = None + + async def initialize_async(self, *, max_concurrency: int = 1) -> None: + self.init_async_max_concurrency = max_concurrency + + +class _FakeScenarioChoice(_FakeScenarioBase): + _schema = [ + RoleDescriptor( + name="mode", + description="Pick a mode", + tag=RoleTag.CHOICE, + param_type=str, + choices=("fast", "thorough"), + ), + ] + + def __init__(self, *, mode: str) -> None: + self.mode = mode + + async def initialize_async(self) -> None: + pass + + +class _FakeScenarioRaises(_FakeScenarioBase): + _schema = [ + RoleDescriptor(name="value", description="anything", tag=RoleTag.SCALAR, param_type=int), + ] + + def __init__(self, *, value: int) -> None: + if value < 0: + raise ValueError(f"value must be non-negative; got {value}") + self.value = value + + async def initialize_async(self) -> None: + pass + + +class _FakeScenarioInitAsyncRaises(_FakeScenarioBase): + def __init__(self) -> None: + pass + + async def initialize_async(self) -> None: + raise RuntimeError("initialize_async failed") + + +class TestDiscoverInputSchema: + def test_returns_list_copy(self): + schema = discover_input_schema(cast("Any", _FakeScenarioScalarRoles)) + assert isinstance(schema, list) + assert len(schema) == 2 + assert schema[0].name == "weakness_label" + + def test_empty_schema_default(self): + schema = discover_input_schema(cast("Any", _FakeScenarioNoArgs)) + assert schema == [] + + def test_discover_does_not_share_mutable_list(self): + """Mutating the returned list does not affect subsequent calls.""" + schema_a = discover_input_schema(cast("Any", _FakeScenarioScalarRoles)) + schema_a.clear() + schema_b = discover_input_schema(cast("Any", _FakeScenarioScalarRoles)) + assert len(schema_b) == 2 + + +class TestDiscoverSupportedParameters: + def test_returns_list(self): + params = discover_supported_parameters(cast("Any", _FakeScenarioScalarRoles)) + assert isinstance(params, list) + assert len(params) == 1 + assert params[0].name == "max_concurrency" + + def test_empty_when_unset(self): + params = discover_supported_parameters(cast("Any", _FakeScenarioNoArgs)) + assert params == [] + + +class TestValidateInitInputs: + def test_all_required_present_passes(self): + validate_init_inputs(schema=_FakeScenarioScalarRoles._schema, init_inputs={"weakness_label": "harm"}) + + def test_missing_required_raises_with_role_name(self): + with pytest.raises(ScenarioInputValidationError) as exc_info: + validate_init_inputs(schema=_FakeScenarioScalarRoles._schema, init_inputs={}) + assert exc_info.value.role_name == "weakness_label" + assert "weakness_label" in str(exc_info.value) + + def test_missing_optional_passes(self): + """Optional role absence is not a validation failure.""" + validate_init_inputs(schema=_FakeScenarioScalarRoles._schema, init_inputs={"weakness_label": "x"}) + + def test_choice_value_in_choices_passes(self): + validate_init_inputs(schema=_FakeScenarioChoice._schema, init_inputs={"mode": "fast"}) + + def test_choice_value_not_in_choices_raises(self): + with pytest.raises(ScenarioInputValidationError) as exc_info: + validate_init_inputs(schema=_FakeScenarioChoice._schema, init_inputs={"mode": "instant"}) + assert exc_info.value.role_name == "mode" + assert "instant" in str(exc_info.value) + + def test_unknown_keys_pass_through_silently(self): + """Scenarios may accept kwargs not in the schema (e.g. scenario_result_id).""" + validate_init_inputs( + schema=_FakeScenarioScalarRoles._schema, + init_inputs={"weakness_label": "x", "scenario_result_id": "abc"}, + ) + + def test_empty_schema_accepts_any_inputs(self): + validate_init_inputs(schema=[], init_inputs={"whatever": 1}) + + +class TestBuildScenarioFromInputs: + async def test_constructs_and_initializes(self): + scenario = await build_scenario_from_inputs( + cast("Any", _FakeScenarioScalarRoles), + init_inputs={"weakness_label": "harm"}, + init_async_inputs={"max_concurrency": 4}, + ) + assert scenario.weakness_label == "harm" # type: ignore[attr-defined] + assert scenario.threshold == 0.5 # type: ignore[attr-defined] + assert scenario.init_async_max_concurrency == 4 # type: ignore[attr-defined] + + async def test_no_args_scenario(self): + scenario = await build_scenario_from_inputs( + cast("Any", _FakeScenarioNoArgs), + init_inputs={}, + init_async_inputs={}, + ) + assert scenario.init_async_called is True # type: ignore[attr-defined] + + async def test_validation_runs_before_construction(self): + """A missing required role raises before ``__init__`` is reached.""" + with pytest.raises(ScenarioInputValidationError) as exc_info: + await build_scenario_from_inputs( + cast("Any", _FakeScenarioScalarRoles), + init_inputs={}, + init_async_inputs={}, + ) + assert exc_info.value.role_name == "weakness_label" + + async def test_construction_errors_propagate(self): + """A ``__init__`` exception is not wrapped — caller gets the original.""" + with pytest.raises(ValueError, match="value must be non-negative"): + await build_scenario_from_inputs( + cast("Any", _FakeScenarioRaises), + init_inputs={"value": -1}, + init_async_inputs={}, + ) + + async def test_initialize_async_errors_propagate(self): + with pytest.raises(RuntimeError, match="initialize_async failed"): + await build_scenario_from_inputs( + cast("Any", _FakeScenarioInitAsyncRaises), + init_inputs={}, + init_async_inputs={}, + ) + + async def test_choice_validation_fires(self): + with pytest.raises(ScenarioInputValidationError) as exc_info: + await build_scenario_from_inputs( + cast("Any", _FakeScenarioChoice), + init_inputs={"mode": "bogus"}, + init_async_inputs={}, + ) + assert exc_info.value.role_name == "mode" + + +class TestScenarioInputValidationError: + def test_is_value_error_subclass(self): + """Lets callers catch with the broader ``ValueError`` if they want.""" + assert issubclass(ScenarioInputValidationError, ValueError) + + def test_role_name_defaults_to_none(self): + exc = ScenarioInputValidationError("bare message") + assert exc.role_name is None + + def test_role_name_round_trips(self): + exc = ScenarioInputValidationError("oops", role_name="x") + assert exc.role_name == "x" + assert str(exc) == "oops" diff --git a/tests/unit/scenario/core/test_input_collector.py b/tests/unit/scenario/core/test_input_collector.py new file mode 100644 index 000000000..c03fa433c --- /dev/null +++ b/tests/unit/scenario/core/test_input_collector.py @@ -0,0 +1,319 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Phase 8c — coverage for the ``InputCollector`` Protocol and reference impls.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from pyrit.scenario.core.builder import ScenarioInputValidationError +from pyrit.scenario.core.input_collector import ( + ArtifactInputCollector, + CliInputCollector, + DictInputCollector, + InputCollector, + MaxAttemptsExceededError, + OpaqueRoleNotElicitableError, + collect_inputs_with_retry, +) +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + +# --- shared fixtures ------------------------------------------------------------- + + +def _scalar_role( + *, + name: str = "label", + description: str = "desc", + param_type: type | None = str, + required: bool = True, + default: Any = None, +) -> RoleDescriptor: + return RoleDescriptor( + name=name, + description=description, + tag=RoleTag.SCALAR, + param_type=param_type, + required=required, + default=default, + ) + + +def _choice_role(*, name: str = "mode", choices: tuple[str, ...] = ("fast", "slow")) -> RoleDescriptor: + return RoleDescriptor( + name=name, + description="pick one", + tag=RoleTag.CHOICE, + param_type=str, + choices=choices, + ) + + +def _opaque_role(*, name: str = "instance") -> RoleDescriptor: + return RoleDescriptor( + name=name, + description="pre-built thing", + tag=RoleTag.OPAQUE, + ) + + +# --- DictInputCollector ---------------------------------------------------------- + + +class TestDictInputCollector: + def test_returns_value_when_present(self): + collector = DictInputCollector({"label": "harm"}) + assert collector.collect(role=_scalar_role()) == "harm" + + def test_missing_required_raises_with_role_name(self): + collector = DictInputCollector({}) + with pytest.raises(ScenarioInputValidationError) as exc_info: + collector.collect(role=_scalar_role(name="x")) + assert exc_info.value.role_name == "x" + + def test_missing_optional_returns_default(self): + collector = DictInputCollector({}) + role = _scalar_role(required=False, default="defaulted") + assert collector.collect(role=role) == "defaulted" + + def test_defensive_copy_of_input_mapping(self): + """Mutating the source dict after construction does not affect the collector.""" + source: dict[str, Any] = {"label": "first"} + collector = DictInputCollector(source) + source["label"] = "second" + assert collector.collect(role=_scalar_role()) == "first" + + def test_collector_implements_protocol(self): + assert isinstance(DictInputCollector({}), InputCollector) + + +# --- CliInputCollector ----------------------------------------------------------- + + +class _ScriptedInput: + """Helper: returns a series of canned responses, panics if exhausted.""" + + def __init__(self, responses: list[str]) -> None: + self._responses = list(responses) + self.prompts: list[str] = [] + + def __call__(self, prompt: str) -> str: + self.prompts.append(prompt) + if not self._responses: + raise AssertionError(f"Scripted input exhausted; got extra prompt {prompt!r}") + return self._responses.pop(0) + + +class _Capture: + def __init__(self) -> None: + self.lines: list[str] = [] + + def __call__(self, msg: str) -> None: + self.lines.append(msg) + + +class TestCliInputCollectorBasic: + def test_scalar_string_value(self): + scripted = _ScriptedInput(["harm"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + assert collector.collect(role=_scalar_role()) == "harm" + assert "label" in scripted.prompts[0] + + def test_scalar_int_coercion(self): + scripted = _ScriptedInput(["42"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + assert collector.collect(role=_scalar_role(param_type=int)) == 42 + + def test_scalar_float_coercion(self): + scripted = _ScriptedInput(["3.14"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + assert collector.collect(role=_scalar_role(param_type=float)) == 3.14 + + @pytest.mark.parametrize( + "raw,expected", [("true", True), ("yes", True), ("1", True), ("false", False), ("n", False), ("0", False)] + ) + def test_scalar_bool_coercion(self, raw, expected): + scripted = _ScriptedInput([raw]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + assert collector.collect(role=_scalar_role(param_type=bool)) is expected + + def test_scalar_bool_invalid_raises_validation_error(self): + scripted = _ScriptedInput(["maybe"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + with pytest.raises(ScenarioInputValidationError) as exc_info: + collector.collect(role=_scalar_role(param_type=bool)) + assert exc_info.value.role_name == "label" + + def test_scalar_int_invalid_raises_validation_error(self): + scripted = _ScriptedInput(["not-a-number"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + with pytest.raises(ScenarioInputValidationError): + collector.collect(role=_scalar_role(param_type=int)) + + +class TestCliInputCollectorChoice: + def test_valid_choice(self): + scripted = _ScriptedInput(["fast"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + assert collector.collect(role=_choice_role()) == "fast" + + def test_invalid_choice_raises(self): + scripted = _ScriptedInput(["bogus"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + with pytest.raises(ScenarioInputValidationError) as exc_info: + collector.collect(role=_choice_role()) + assert exc_info.value.role_name == "mode" + + def test_choice_prompt_lists_choices(self): + scripted = _ScriptedInput(["fast"]) + capture = _Capture() + collector = CliInputCollector(input_fn=scripted, output_fn=capture) + collector.collect(role=_choice_role()) + assert "fast" in scripted.prompts[0] + assert "slow" in scripted.prompts[0] + + +class TestCliInputCollectorOpaque: + def test_opaque_role_raises_clear_error(self): + collector = CliInputCollector(input_fn=_ScriptedInput([]), output_fn=_Capture()) + with pytest.raises(OpaqueRoleNotElicitableError) as exc_info: + collector.collect(role=_opaque_role(name="atomic_attack")) + assert exc_info.value.role_name == "atomic_attack" + assert "from-artifact" in str(exc_info.value) + + +class TestCliInputCollectorOptional: + def test_blank_response_uses_default(self): + scripted = _ScriptedInput([""]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + role = _scalar_role(required=False, default="defaulted") + assert collector.collect(role=role) == "defaulted" + + def test_blank_response_required_returns_empty_string(self): + """Empty response on a required role falls through to coercion (or no coercion).""" + scripted = _ScriptedInput([""]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + # str role, no coercion needed — empty string is returned as-is for required str roles. + assert collector.collect(role=_scalar_role(param_type=str)) == "" + + +class TestCliInputCollectorErrorRecovery: + def test_error_argument_renders_before_prompt(self): + scripted = _ScriptedInput(["harm"]) + capture = _Capture() + collector = CliInputCollector(input_fn=scripted, output_fn=capture) + prior_error = ScenarioInputValidationError("bad value", role_name="label") + collector.collect(role=_scalar_role(), error=prior_error, attempt=1) + assert any("Previous attempt failed" in line for line in capture.lines) + assert any("bad value" in line for line in capture.lines) + + def test_attempt_counter_shows_in_prompt(self): + scripted = _ScriptedInput(["x"]) + collector = CliInputCollector(input_fn=scripted, output_fn=_Capture()) + collector.collect(role=_scalar_role(), attempt=2) + assert "attempt 3" in scripted.prompts[0] + + +# --- ArtifactInputCollector ------------------------------------------------------ + + +class TestArtifactInputCollector: + def test_replays_value(self): + collector = ArtifactInputCollector({"label": "from-artifact"}) + assert collector.collect(role=_scalar_role()) == "from-artifact" + + def test_missing_required_raises(self): + collector = ArtifactInputCollector({}) + with pytest.raises(ScenarioInputValidationError) as exc_info: + collector.collect(role=_scalar_role()) + assert exc_info.value.role_name == "label" + + def test_missing_optional_returns_default(self): + collector = ArtifactInputCollector({}) + role = _scalar_role(required=False, default="d") + assert collector.collect(role=role) == "d" + + def test_opaque_payload_passes_through(self): + """Opaque values stored in the artifact (as ``ComponentIdentifier.to_dict()``) replay verbatim.""" + opaque_payload = {"class_name": "AtomicAttack", "init_data": {}} + collector = ArtifactInputCollector({"instance": opaque_payload}) + assert collector.collect(role=_opaque_role()) == opaque_payload + + +# --- collect_inputs_with_retry --------------------------------------------------- + + +class _FlakyCollector: + """Collector that fails the first ``fail_n`` attempts, then succeeds.""" + + def __init__(self, *, fail_n: int, success_value: Any) -> None: + self.fail_n = fail_n + self.success_value = success_value + self.attempts_seen: list[int] = [] + self.errors_seen: list[Exception | None] = [] + + def collect(self, *, role: RoleDescriptor, error: Exception | None = None, attempt: int = 0) -> Any: + self.attempts_seen.append(attempt) + self.errors_seen.append(error) + if attempt < self.fail_n: + raise ScenarioInputValidationError(f"attempt {attempt} failed", role_name=role.name) + return self.success_value + + +class TestCollectInputsWithRetrySuccess: + def test_single_role_first_try(self): + collector = DictInputCollector({"label": "v"}) + result = collect_inputs_with_retry(collector=collector, schema=[_scalar_role()]) + assert result == {"label": "v"} + + def test_multiple_roles_in_declared_order(self): + schema = [_scalar_role(name="a"), _scalar_role(name="b"), _scalar_role(name="c")] + collector = DictInputCollector({"a": 1, "b": 2, "c": 3}) + result = collect_inputs_with_retry(collector=collector, schema=schema) + assert list(result.keys()) == ["a", "b", "c"] + + def test_optional_missing_role_omitted_from_result(self): + """If a collector returns the default, the role is still in the result map.""" + schema = [_scalar_role(name="a", required=False, default="d")] + collector = DictInputCollector({}) + result = collect_inputs_with_retry(collector=collector, schema=schema) + assert result == {"a": "d"} + + +class TestCollectInputsWithRetryRetries: + def test_succeeds_on_third_attempt(self): + collector = _FlakyCollector(fail_n=2, success_value="ok") + result = collect_inputs_with_retry(collector=collector, schema=[_scalar_role()], max_attempts=5) + assert result == {"label": "ok"} + assert collector.attempts_seen == [0, 1, 2] + + def test_error_propagates_to_next_attempt(self): + collector = _FlakyCollector(fail_n=1, success_value="ok") + collect_inputs_with_retry(collector=collector, schema=[_scalar_role()], max_attempts=5) + assert collector.errors_seen[0] is None + assert isinstance(collector.errors_seen[1], ScenarioInputValidationError) + assert collector.errors_seen[1].role_name == "label" # type: ignore[union-attr] + + def test_exhausts_attempts_and_raises(self): + collector = _FlakyCollector(fail_n=10, success_value="never") + with pytest.raises(MaxAttemptsExceededError) as exc_info: + collect_inputs_with_retry(collector=collector, schema=[_scalar_role()], max_attempts=3) + assert exc_info.value.role_name == "label" + assert exc_info.value.attempts == 3 + assert isinstance(exc_info.value.last_error, ScenarioInputValidationError) + # Tried exactly max_attempts times (no extra retry past the budget). + assert collector.attempts_seen == [0, 1, 2] + + +class TestOpaqueRoleNotElicitableError: + def test_message_points_at_artifact_path(self): + exc = OpaqueRoleNotElicitableError("atomic_attack") + assert "atomic_attack" in str(exc) + assert "from-artifact" in str(exc) + + def test_is_not_implemented_error(self): + """Allows callers to use the broader ``NotImplementedError`` catch.""" + assert issubclass(OpaqueRoleNotElicitableError, NotImplementedError) From d6148fe6796f1905cd29deb8ee33fe5b518467fb Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:45:52 -0700 Subject: [PATCH 21/42] FEAT: Phase 8g graph artifact + load contract --- pyrit/registry/discovery.py | 16 +- pyrit/scenario/__init__.py | 22 + pyrit/scenario/core/__init__.py | 24 + pyrit/scenario/core/graph_artifact.py | 545 ++++++++++++++++++ .../registry/test_discovery_resilience.py | 89 +++ .../unit/scenario/core/test_graph_artifact.py | 485 ++++++++++++++++ 6 files changed, 1177 insertions(+), 4 deletions(-) create mode 100644 pyrit/scenario/core/graph_artifact.py create mode 100644 tests/unit/registry/test_discovery_resilience.py create mode 100644 tests/unit/scenario/core/test_graph_artifact.py diff --git a/pyrit/registry/discovery.py b/pyrit/registry/discovery.py index 5df0c14fe..6d1eb165d 100644 --- a/pyrit/registry/discovery.py +++ b/pyrit/registry/discovery.py @@ -130,10 +130,18 @@ def name_builder(prefix: str, name: str) -> str: # For non-package modules, find and yield subclasses if not is_pkg: for _name, obj in inspect.getmembers(module, inspect.isclass): - if issubclass(obj, base_class) and obj is not base_class and not inspect.isabstract(obj): - # Build the registry name including any prefix - registry_name = name_builder(_prefix, module_name) - yield (registry_name, obj) + # ``inspect.isclass`` returns True for parameterized type aliases + # like ``Callable[[X], Y]`` that are not real classes; guard + # ``issubclass`` so a single such alias doesn't terminate + # discovery for the whole module. + try: + if not (issubclass(obj, base_class) and obj is not base_class and not inspect.isabstract(obj)): + continue + except TypeError: + continue + # Build the registry name including any prefix + registry_name = name_builder(_prefix, module_name) + yield (registry_name, obj) # Recursively discover in subpackages if recursive and is_pkg: diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index be35b71d7..f5e9c24c0 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -26,8 +26,13 @@ CliInputCollector, DatasetConfiguration, DictInputCollector, + GraphArtifact, + GraphArtifactDriftError, + GraphArtifactError, + GraphArtifactSecurityError, InputCollector, MaxAttemptsExceededError, + OpaqueInputUnresolvedError, OpaqueRoleNotElicitableError, PolicyAction, RoleDescriptor, @@ -42,11 +47,17 @@ ScenarioStrategy, StrategyGraph, StrategyPolicy, + build_graph_artifact, build_scenario_from_inputs, + build_topology_summary, collect_inputs_with_retry, discover_input_schema, discover_supported_parameters, + graph_artifact_from_yaml, + graph_artifact_to_yaml, linear_strategy_policy, + load_scenario_from_artifact, + materialize_opaque_inputs, policy_to_spec, spec_to_enum, validate_init_inputs, @@ -83,8 +94,13 @@ "CliInputCollector", "DatasetConfiguration", "DictInputCollector", + "GraphArtifact", + "GraphArtifactDriftError", + "GraphArtifactError", + "GraphArtifactSecurityError", "InputCollector", "MaxAttemptsExceededError", + "OpaqueInputUnresolvedError", "OpaqueRoleNotElicitableError", "Parameter", "PolicyAction", @@ -105,13 +121,19 @@ "adaptive", "airt", "benchmark", + "build_graph_artifact", "build_scenario_from_inputs", + "build_topology_summary", "collect_inputs_with_retry", "discover_input_schema", "discover_supported_parameters", "foundry", "garak", + "graph_artifact_from_yaml", + "graph_artifact_to_yaml", "linear_strategy_policy", + "load_scenario_from_artifact", + "materialize_opaque_inputs", "policy_to_spec", "spec_to_enum", "validate_init_inputs", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index d704f620b..8f6406826 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -15,6 +15,19 @@ validate_init_inputs, ) from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration +from pyrit.scenario.core.graph_artifact import ( + GraphArtifact, + GraphArtifactDriftError, + GraphArtifactError, + GraphArtifactSecurityError, + OpaqueInputUnresolvedError, + build_graph_artifact, + build_topology_summary, + graph_artifact_from_yaml, + graph_artifact_to_yaml, + load_scenario_from_artifact, + materialize_opaque_inputs, +) from pyrit.scenario.core.input_collector import ( ArtifactInputCollector, CliInputCollector, @@ -47,8 +60,13 @@ "DatasetConfiguration", "DictInputCollector", "EXPLICIT_SEED_GROUPS_KEY", + "GraphArtifact", + "GraphArtifactDriftError", + "GraphArtifactError", + "GraphArtifactSecurityError", "InputCollector", "MaxAttemptsExceededError", + "OpaqueInputUnresolvedError", "OpaqueRoleNotElicitableError", "Parameter", "PolicyAction", @@ -66,13 +84,19 @@ "ScorerOverridePolicy", "StrategyGraph", "StrategyPolicy", + "build_graph_artifact", "build_scenario_from_inputs", + "build_topology_summary", "collect_inputs_with_retry", "discover_input_schema", "discover_supported_parameters", "get_default_adversarial_target", "get_default_scorer_target", + "graph_artifact_from_yaml", + "graph_artifact_to_yaml", "linear_strategy_policy", + "load_scenario_from_artifact", + "materialize_opaque_inputs", "policy_to_spec", "register_scenario_techniques", "spec_to_enum", diff --git a/pyrit/scenario/core/graph_artifact.py b/pyrit/scenario/core/graph_artifact.py new file mode 100644 index 000000000..5014d7e7d --- /dev/null +++ b/pyrit/scenario/core/graph_artifact.py @@ -0,0 +1,545 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Graph artifact — reproducible config capture for Python-authored scenarios. + +A :class:`GraphArtifact` is a fully serializable snapshot of an initialized +scenario's *configuration* (the inputs you'd give the wizard) plus a +*topology snapshot* (what the underlying graph looked like at build time). +Reloading replays the configuration through +:func:`pyrit.scenario.core.builder.build_scenario_from_inputs`, then asserts +the rebuilt graph matches the snapshot. + +Scope (deliberately narrow): this enables ``pyrit_scan --from-artifact path.yaml`` +to reproduce a wizard-built scenario. It does NOT enable authoring new +scenario topologies from YAML; transitions live in Python closures inside +``_build_execution_graph`` and are not serialized. Drift detection is +*structural* (state set, step identifiers) not behavioral (predicate bodies). + +Security: the load path resolves the captured scenario class FQN through +:class:`pyrit.registry.class_registries.scenario_registry.ScenarioRegistry`'s +self-discovered whitelist. Unregistered FQNs are rejected. Registered +scenarios can still run arbitrary code in ``__init__`` — this is acceptable +for PyRIT's trusted-developer threat model and documented here loudly. +""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import asdict, dataclass, field +from importlib import import_module +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import yaml + +import pyrit +from pyrit.scenario.core.builder import build_scenario_from_inputs +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + from pyrit.prompt_target import PromptTarget + from pyrit.scenario.core.dataset_configuration import DatasetConfiguration + from pyrit.scenario.core.scenario import Scenario + + +_ARTIFACT_VERSION = 1 + + +class GraphArtifactError(Exception): + """Base class for graph-artifact failures.""" + + +class GraphArtifactSecurityError(GraphArtifactError): + """Raised when an artifact's ``scenario_class_fqn`` is not registry-whitelisted.""" + + +class GraphArtifactDriftError(GraphArtifactError): + """Raised when a loaded scenario's rebuilt topology disagrees with the artifact snapshot.""" + + +class OpaqueInputUnresolvedError(GraphArtifactError): + """Raised when an opaque role's stored payload cannot be rematerialized at load time.""" + + def __init__(self, role_name: str, payload: Any) -> None: + """Format a help message naming the role and pointing at ``opaque_materializers``.""" + super().__init__( + f"Opaque role {role_name!r} has a stored identifier payload but no materializer " + "was supplied. Pass ``opaque_materializers={role_name: callable}`` to " + "``load_scenario_from_artifact`` to rebuild the instance from its identifier." + ) + self.role_name = role_name + self.payload = payload + + +# --- helpers --------------------------------------------------------------------- + + +def _canonical_json(obj: Any) -> str: + """Return a deterministic JSON encoding suitable for hashing.""" + return json.dumps(obj, sort_keys=True, separators=(",", ":"), default=str) + + +def _class_fqn(cls: type) -> str: + """Return the canonical ``module.qualname`` FQN for ``cls``.""" + return f"{cls.__module__}.{cls.__qualname__}" + + +def _resolve_scenario_fqn(fqn: str) -> type[Scenario]: + """ + Resolve a captured FQN to a live :class:`Scenario` class. + + Security: the resolved class MUST appear in the discovered + :class:`ScenarioRegistry` to be accepted. Unregistered FQNs are rejected + even if the FQN points at a real Python class. + + Returns: + type[Scenario]: The resolved scenario class. + + Raises: + GraphArtifactSecurityError: If the FQN is not dotted, cannot be + imported, does not resolve to a class, is not a ``Scenario`` + subclass, or is not in the registry whitelist. + """ + from pyrit.registry.class_registries.scenario_registry import ScenarioRegistry + from pyrit.scenario.core.scenario import Scenario + + module_path, _, class_name = fqn.rpartition(".") + if not module_path: + raise GraphArtifactSecurityError(f"Scenario FQN {fqn!r} is not a dotted path.") + + try: + module = import_module(module_path) + cls = getattr(module, class_name, None) + except ImportError as exc: + raise GraphArtifactSecurityError(f"Failed to import module for FQN {fqn!r}: {exc}") from exc + + if cls is None: + raise GraphArtifactSecurityError(f"FQN {fqn!r} did not resolve to a class.") + + if not isinstance(cls, type) or not issubclass(cls, Scenario): + raise GraphArtifactSecurityError(f"FQN {fqn!r} resolved to {cls!r}, which is not a Scenario subclass.") + + registry = ScenarioRegistry() + registry._ensure_discovered() + registered_classes = {entry.registered_class for entry in registry._class_entries.values()} + if cls not in registered_classes: + raise GraphArtifactSecurityError( + f"Scenario class {fqn!r} is not in the registry whitelist. " + "Only registered scenarios can be loaded from artifacts." + ) + + return cls + + +def _serialize_dataset_config(cfg: DatasetConfiguration) -> dict[str, Any]: + """ + Serialize a :class:`DatasetConfiguration` to a JSON-compatible dict. + + Only ``dataset_names`` and ``max_dataset_size`` are captured. Explicit + ``seed_groups`` are NOT serialized in 8g MVP — artifacts built from + explicit-seed-group configs will fail to fully round-trip if the caller + doesn't re-supply equivalent groups via load-time overrides. + + Returns: + dict[str, Any]: A JSON-compatible mapping carrying ``dataset_names``, + ``max_dataset_size``, and a marker count for explicit seed groups. + """ + if cfg is None: + return {} + return { + "dataset_names": list(cfg._dataset_names) if cfg._dataset_names is not None else None, + "max_dataset_size": cfg.max_dataset_size, + "explicit_seed_groups_count": len(cfg._seed_groups) if cfg._seed_groups else 0, + } + + +def _deserialize_dataset_config(payload: Mapping[str, Any]) -> DatasetConfiguration | None: + """ + Reconstruct a :class:`DatasetConfiguration` from a serialized payload. + + Returns: + DatasetConfiguration | None: A reconstructed config, or ``None`` when + the payload is empty (i.e. the original scenario had no dataset + configuration to capture). + + Raises: + GraphArtifactError: If the payload claims explicit seed groups, which + 8g MVP does not support round-tripping. + """ + from pyrit.scenario.core.dataset_configuration import DatasetConfiguration + + if not payload: + return None + if payload.get("explicit_seed_groups_count", 0) > 0: + raise GraphArtifactError( + "Artifact captured a DatasetConfiguration with explicit seed_groups, which 8g MVP " + "does not round-trip. Supply seed_groups via a build-time override or rebuild the " + "artifact from a dataset_names-based configuration." + ) + return DatasetConfiguration( + dataset_names=payload.get("dataset_names"), + max_dataset_size=payload.get("max_dataset_size"), + ) + + +# --- dataclass ------------------------------------------------------------------- + + +@dataclass(frozen=True) +class GraphArtifact: + """ + Reproducible snapshot of an initialized scenario's configuration + topology. + + Attributes: + scenario_class_fqn: ``module.ClassName`` for whitelist-resolved load. + scenario_version: Mirrors the scenario's instance-time ``version=`` arg + (read from ``scenario._identifier.scenario_version``). + pyrit_version: Stamped from ``pyrit.__version__`` at build time. + artifact_version: Bumped only on backward-incompatible schema changes. + init_inputs: Validated against ``input_schema()``. Opaque values are + stored as ``value.get_identifier().to_dict()`` payloads. + init_async_inputs: Scalar arguments forwarded to ``initialize_async``. + Typically the contents of ``self.params``. + scenario_strategies: Enum-member names of the strategies selected at + initialize time (e.g. ``["EASY", "HARD"]``). + dataset_config: Serialized :class:`DatasetConfiguration` (subset). + include_baseline: Resolved boolean from the scenario's effective + :class:`BaselineAttackPolicy`. + params: Snapshot of ``self.params`` after ``set_params_from_args``. + memory_labels: Run-time memory labels. + topology_hash: ``sha256(canonical_json(topology_summary))`` for drift + comparison at load time. + topology_summary: Human-readable structural snapshot (states, terminals, + atomic-attack identifiers). + state_enum_fqn: ``module.Enum`` for branching scenarios that use an + Enum state type; ``None`` for legacy ``int``-keyed linear scenarios. + """ + + scenario_class_fqn: str + scenario_version: int + pyrit_version: str + artifact_version: int = _ARTIFACT_VERSION + + init_inputs: dict[str, Any] = field(default_factory=dict) + init_async_inputs: dict[str, Any] = field(default_factory=dict) + + scenario_strategies: list[str] = field(default_factory=list) + dataset_config: dict[str, Any] = field(default_factory=dict) + include_baseline: bool = False + params: dict[str, Any] = field(default_factory=dict) + memory_labels: dict[str, str] = field(default_factory=dict) + + topology_hash: str = "" + topology_summary: dict[str, Any] = field(default_factory=dict) + state_enum_fqn: str | None = None + + +# --- build path ------------------------------------------------------------------ + + +def build_topology_summary(scenario: Scenario) -> dict[str, Any]: + """ + Produce a deterministic structural snapshot of ``scenario.execution_graph``. + + Captures states, initial / terminal state names, and the atomic-attack + identifier list. Per-state step bindings are NOT recorded (they live in + closure bodies inside the policy actions and cannot be introspected). + Drift checks therefore catch state-set changes, atomic-attack changes, and + policy-initial / terminal changes — not behavioral changes inside + transition predicates. + + Args: + scenario: A scenario that has completed ``initialize_async`` (so + ``execution_graph`` and ``_atomic_attacks`` are populated). + + Returns: + dict[str, Any]: A JSON-compatible mapping suitable for hashing. + + Raises: + ValueError: If the scenario has not been initialized (no execution graph). + """ + if scenario.execution_graph is None: + # Fall back to building the graph from current atomic attacks so callers + # can snapshot a scenario whose run_async hasn't been invoked yet. + scenario._execution_graph = scenario._build_execution_graph(steps=scenario._atomic_attacks) + + graph = scenario.execution_graph + assert graph is not None # narrowed above + policy = graph.policy + + return { + "scenario_class_fqn": _class_fqn(type(scenario)), + "scenario_version": scenario._identifier.version, + "states": sorted([str(state) for state in policy.actions]), + "initial_state": str(policy.initial_state), + "terminal_states": sorted([str(state) for state in policy.terminal_states]), + "atomic_attacks": [atomic.get_identifier().to_dict() for atomic in scenario._atomic_attacks], + } + + +def _topology_hash(summary: Mapping[str, Any]) -> str: + """ + Compute the canonical sha256 hash for a topology summary. + + Returns: + str: The hex digest of ``sha256(_canonical_json(summary))``. + """ + return hashlib.sha256(_canonical_json(summary).encode("utf-8")).hexdigest() + + +def _encode_init_inputs( + *, + schema: list[RoleDescriptor], + init_inputs: Mapping[str, Any], +) -> dict[str, Any]: + """ + Encode init_inputs for serialization, snapshotting OPAQUE roles as identifier dicts. + + Returns: + dict[str, Any]: A serialization-safe mapping with opaque live instances + replaced by their ``ComponentIdentifier.to_dict()`` payloads. + """ + encoded: dict[str, Any] = {} + schema_by_name = {role.name: role for role in schema} + for name, value in init_inputs.items(): + role = schema_by_name.get(name) + if role is not None and role.tag is RoleTag.OPAQUE and value is not None: + if hasattr(value, "get_identifier"): + encoded[name] = value.get_identifier().to_dict() + else: + # Unknown opaque shape — defer to caller serialization at their own risk. + encoded[name] = value + else: + encoded[name] = value + return encoded + + +def build_graph_artifact( + scenario: Scenario, + *, + init_inputs: Mapping[str, Any] | None = None, + init_async_inputs: Mapping[str, Any] | None = None, +) -> GraphArtifact: + """ + Snapshot an initialized scenario as a :class:`GraphArtifact`. + + Args: + scenario: Must have completed ``initialize_async``. + init_inputs: The rich-object ``__init__`` arguments the scenario was + built with. Required because ``Scenario`` does not store its + constructor args directly — they're embedded in opaque attributes + (``objective_scorer``, ``strategy_class``) that aren't easy to + recover after construction. + init_async_inputs: The scalar ``initialize_async`` arguments. Optional; + defaults to ``self.params``. + + Returns: + GraphArtifact: A frozen snapshot ready for YAML serialization. + """ + init_inputs = dict(init_inputs or {}) + init_async_inputs = dict(init_async_inputs or scenario.params or {}) + + schema = list(type(scenario).input_schema()) + encoded_init_inputs = _encode_init_inputs(schema=schema, init_inputs=init_inputs) + + strategy_names = [ + strategy.name if hasattr(strategy, "name") else str(strategy) for strategy in scenario._scenario_strategies + ] + strategy_cls = type(scenario).get_strategy_class() + state_enum_fqn = _class_fqn(strategy_cls) if strategy_cls is not None else None + + topology = build_topology_summary(scenario) + + return GraphArtifact( + scenario_class_fqn=_class_fqn(type(scenario)), + scenario_version=scenario._identifier.version, + pyrit_version=pyrit.__version__, + init_inputs=encoded_init_inputs, + init_async_inputs=init_async_inputs, + scenario_strategies=strategy_names, + dataset_config=_serialize_dataset_config(scenario._dataset_config), + include_baseline=scenario._include_baseline, + params=dict(scenario.params), + memory_labels=dict(scenario._memory_labels), + topology_hash=_topology_hash(topology), + topology_summary=topology, + state_enum_fqn=state_enum_fqn, + ) + + +# --- YAML I/O -------------------------------------------------------------------- + + +def graph_artifact_to_yaml(artifact: GraphArtifact, path: str | Path) -> None: + """ + Write ``artifact`` to ``path`` in deterministic YAML form. + + Two artifacts that compare equal at the dataclass level write byte-identical + YAML thanks to ``sort_keys=True`` and ``default_flow_style=False``. This is + the contract integration tests rely on for the byte-identical-output gate. + """ + payload = asdict(artifact) + Path(path).write_text( + yaml.safe_dump(payload, sort_keys=True, default_flow_style=False), + encoding="utf-8", + ) + + +def graph_artifact_from_yaml(path: str | Path) -> GraphArtifact: + """ + Read a :class:`GraphArtifact` back from ``path``. + + Performs an ``artifact_version`` check so older artifacts surface a clear + error instead of silently mismatching the dataclass schema. + + Returns: + GraphArtifact: The deserialized artifact. + + Raises: + GraphArtifactError: If the YAML payload is not a mapping, or if its + ``artifact_version`` does not match the version this PyRIT + understands. + """ + payload = yaml.safe_load(Path(path).read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise GraphArtifactError(f"Artifact at {path!r} is not a mapping; got {type(payload).__name__}.") + + artifact_version = payload.get("artifact_version", 0) + if artifact_version != _ARTIFACT_VERSION: + raise GraphArtifactError( + f"Artifact at {path!r} has artifact_version={artifact_version}, " + f"but this PyRIT understands artifact_version={_ARTIFACT_VERSION}. " + "Rebuild the artifact with the current version." + ) + + return GraphArtifact(**payload) + + +# --- load path ------------------------------------------------------------------- + + +def materialize_opaque_inputs( + cls: type[Scenario], + init_inputs: Mapping[str, Any], + *, + opaque_materializers: Mapping[str, Callable[[dict[str, Any]], Any]] | None = None, +) -> dict[str, Any]: + """ + Rebuild live opaque instances from their stored identifier payloads. + + For each OPAQUE role in ``cls.input_schema()``: + + * If the stored value is already a live instance (not a dict), pass through. + * If the stored value is a dict and a callable is registered in + ``opaque_materializers[role.name]``, invoke it with the dict. + * Otherwise raise :class:`OpaqueInputUnresolvedError`. + + SCALAR / CHOICE roles pass through unchanged. + + Args: + cls: The scenario class. + init_inputs: Stored init inputs from the artifact. + opaque_materializers: Optional mapping from role name to a callable + that consumes the stored identifier dict and returns a live instance. + + Returns: + dict[str, Any]: Materialized init inputs ready for + :func:`build_scenario_from_inputs`. + + Raises: + OpaqueInputUnresolvedError: When an opaque role has a dict payload but + no materializer is provided. + """ + materializers = dict(opaque_materializers or {}) + schema = {role.name: role for role in cls.input_schema()} + out: dict[str, Any] = {} + for name, value in init_inputs.items(): + role = schema.get(name) + if role is not None and role.tag is RoleTag.OPAQUE and isinstance(value, dict): + if name not in materializers: + raise OpaqueInputUnresolvedError(name, value) + out[name] = materializers[name](value) + else: + out[name] = value + return out + + +async def load_scenario_from_artifact( + artifact: GraphArtifact, + *, + objective_target: PromptTarget, + allow_drift: bool = False, + opaque_materializers: Mapping[str, Callable[[dict[str, Any]], Any]] | None = None, +) -> Scenario: + """ + Rebuild and initialize a scenario from a :class:`GraphArtifact`. + + Args: + artifact: The artifact (typically loaded via :func:`graph_artifact_from_yaml`). + objective_target: The target to run the scenario against. NOT captured + in the artifact (it's environment-specific and frequently opaque); + always required at load time. + allow_drift: When ``True``, version + topology-hash mismatches are + logged but not fatal. Default ``False`` mirrors the strict-fail + resume contract on :class:`Scenario`. + opaque_materializers: Per-role-name callables for rebuilding opaque + ``init_inputs`` from their stored identifier payloads. See + :func:`materialize_opaque_inputs`. + + Returns: + Scenario: A fully initialized scenario equivalent (modulo drift) to the + one that produced the artifact. + + Raises: + GraphArtifactSecurityError: If the captured FQN is not registry-whitelisted. + GraphArtifactDriftError: On version or topology-hash mismatch when + ``allow_drift=False``. + OpaqueInputUnresolvedError: If an opaque role has no materializer. + """ + cls = _resolve_scenario_fqn(artifact.scenario_class_fqn) + + if cls.__name__ != artifact.scenario_class_fqn.rsplit(".", 1)[-1] and not allow_drift: + raise GraphArtifactDriftError( + f"Class name mismatch: artifact claims {artifact.scenario_class_fqn!r} but resolved {cls.__name__!r}." + ) + + materialized_init_inputs = materialize_opaque_inputs( + cls, + artifact.init_inputs, + opaque_materializers=opaque_materializers, + ) + + strategy_cls = cls.get_strategy_class() + rebuilt_strategies = [strategy_cls[name] for name in artifact.scenario_strategies] + dataset_config = _deserialize_dataset_config(artifact.dataset_config) + + init_async = dict(artifact.init_async_inputs) + init_async.setdefault("objective_target", objective_target) + init_async.setdefault("scenario_strategies", rebuilt_strategies) + if dataset_config is not None: + init_async.setdefault("dataset_config", dataset_config) + init_async.setdefault("include_baseline", artifact.include_baseline) + if artifact.memory_labels: + init_async.setdefault("memory_labels", artifact.memory_labels) + + scenario = await build_scenario_from_inputs( + cls, + init_inputs=materialized_init_inputs, + init_async_inputs=init_async, + ) + + rebuilt_summary = build_topology_summary(scenario) + rebuilt_hash = _topology_hash(rebuilt_summary) + if rebuilt_hash != artifact.topology_hash and not allow_drift: + raise GraphArtifactDriftError( + f"Topology hash mismatch after rebuild. " + f"Artifact: {artifact.topology_hash}, rebuilt: {rebuilt_hash}. " + "Set allow_drift=True to bypass." + ) + + return scenario diff --git a/tests/unit/registry/test_discovery_resilience.py b/tests/unit/registry/test_discovery_resilience.py new file mode 100644 index 000000000..581c0c6d6 --- /dev/null +++ b/tests/unit/registry/test_discovery_resilience.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Regression coverage for ``discover_in_package`` resilience to non-class type aliases.""" + +from __future__ import annotations + +import textwrap +from typing import TYPE_CHECKING + +import pytest + +from pyrit.registry.discovery import discover_in_package + +if TYPE_CHECKING: + from pathlib import Path + + +def _write_module(path: Path, body: str) -> None: + path.write_text(textwrap.dedent(body), encoding="utf-8") + + +@pytest.fixture +def poisoned_package(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> tuple[Path, str, type]: + """ + Build a synthetic package with a module that exposes a parameterized + ``Callable`` alias alongside a real concrete subclass of a synthetic base. + The base class lives inside the fixture package so the synthetic module + can import it without depending on the ``tests`` namespace being importable. + """ + pkg_root = tmp_path / "_discovery_pkg" + pkg_root.mkdir() + _write_module( + pkg_root / "__init__.py", + """ + class DiscoveryBase: + pass + """, + ) + + _write_module( + pkg_root / "good_module.py", + """ + from collections.abc import Callable + from _discovery_pkg import DiscoveryBase + + # Parameterized Callable type alias — inspect.isclass(alias) is True + # but issubclass(alias, anything) raises TypeError. Before the fix this + # poisoned the rest of the module's discovery. + Poisoned = Callable[[int], str] + + class _RealConcreteAfter(DiscoveryBase): + pass + """, + ) + + monkeypatch.syspath_prepend(str(tmp_path)) + import importlib + + pkg = importlib.import_module("_discovery_pkg") + return pkg_root, "_discovery_pkg", pkg.DiscoveryBase + + +def test_discovery_skips_non_class_aliases(poisoned_package): + pkg_root, pkg_name, base_cls = poisoned_package + discovered = list( + discover_in_package( + package_path=pkg_root, + package_name=pkg_name, + base_class=base_cls, + recursive=False, + ) + ) + discovered_names = {cls.__name__ for _, cls in discovered} + assert "_RealConcreteAfter" in discovered_names + + +def test_text_adaptive_registers_in_scenario_registry(): + """ + End-to-end regression: ``TextAdaptive`` lives in a module that exposes a + parameterized Callable alias (``ContextExtractor``) at module scope. It + must still appear in the registry after discovery. + """ + from pyrit.registry.class_registries.scenario_registry import ScenarioRegistry + + registry = ScenarioRegistry() + registry._ensure_discovered() + names = sorted(registry._class_entries.keys()) + assert "adaptive.text_adaptive" in names diff --git a/tests/unit/scenario/core/test_graph_artifact.py b/tests/unit/scenario/core/test_graph_artifact.py new file mode 100644 index 000000000..297ffc6a7 --- /dev/null +++ b/tests/unit/scenario/core/test_graph_artifact.py @@ -0,0 +1,485 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Phase 8g — coverage for ``graph_artifact`` build / serialize / load primitives.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +import yaml + +from pyrit.scenario.core.graph_artifact import ( + GraphArtifact, + GraphArtifactDriftError, + GraphArtifactError, + GraphArtifactSecurityError, + OpaqueInputUnresolvedError, + _canonical_json, + _class_fqn, + _deserialize_dataset_config, + _encode_init_inputs, + _resolve_scenario_fqn, + _serialize_dataset_config, + _topology_hash, + build_graph_artifact, + build_topology_summary, + graph_artifact_from_yaml, + graph_artifact_to_yaml, + materialize_opaque_inputs, +) +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + +# --- fake scenario surface ------------------------------------------------------ +# +# build_graph_artifact reads a small slice of the Scenario lifecycle. We model +# that slice with a real class so ``type(scenario)`` round-trips through +# ``_class_fqn`` and class methods (``input_schema``, ``get_strategy_class``) +# work without monkeying with ``__class__``. + + +class _FakeStrategyEnum: + """Stand-in for a ScenarioStrategy enum class — only needs an FQN.""" + + +class _FakeScenario: + """Minimal Scenario-shaped object for graph_artifact unit tests.""" + + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return [] + + @classmethod + def get_strategy_class(cls) -> type: + return _FakeStrategyEnum + + +def _fake_atomic(*, name: str, hash_value: str = "abc123") -> MagicMock: + """A mock that satisfies ``atomic.get_identifier().to_dict()``.""" + atomic = MagicMock() + identifier = MagicMock() + identifier.to_dict.return_value = {"name": name, "hash": hash_value} + atomic.get_identifier.return_value = identifier + return atomic + + +def _fake_initialized_scenario( + *, + version: int = 1, + atomic_names: tuple[str, ...] = ("step_a", "step_b"), + states: tuple[str, ...] = ("STATE_0", "STATE_1"), + initial_state: str = "STATE_0", + terminal_states: tuple[str, ...] = ("STATE_DONE",), + strategies: tuple[str, ...] = ("EASY",), + dataset_names: list[str] | None = None, + max_dataset_size: int | None = None, + include_baseline: bool = False, + params: dict[str, Any] | None = None, + memory_labels: dict[str, str] | None = None, +) -> _FakeScenario: + """ + A lightly-mocked stand-in for a fully-initialized ``Scenario``. + + The graph artifact build path only reads a small surface area — we mirror + that surface without spinning up the real lifecycle. + """ + scenario = _FakeScenario() + + identifier = MagicMock() + identifier.version = version + scenario._identifier = identifier # type: ignore[attr-defined] + + strategy_objs = [] + for name in strategies: + s = MagicMock() + s.name = name + strategy_objs.append(s) + scenario._scenario_strategies = strategy_objs # type: ignore[attr-defined] + + cfg = MagicMock() + cfg._dataset_names = dataset_names + cfg.max_dataset_size = max_dataset_size + cfg._seed_groups = None + scenario._dataset_config = cfg # type: ignore[attr-defined] + + scenario._include_baseline = include_baseline # type: ignore[attr-defined] + scenario.params = params or {} # type: ignore[attr-defined] + scenario._memory_labels = memory_labels or {} # type: ignore[attr-defined] + + scenario._atomic_attacks = [_fake_atomic(name=n) for n in atomic_names] # type: ignore[attr-defined] + + policy = MagicMock() + policy.actions = {s: (lambda _g: None) for s in states} + policy.initial_state = initial_state + policy.terminal_states = frozenset(terminal_states) + graph = MagicMock() + graph.policy = policy + scenario.execution_graph = graph # type: ignore[attr-defined] + + return scenario + + +# --- _canonical_json ------------------------------------------------------------- + + +class TestCanonicalJson: + def test_sorts_keys(self): + assert _canonical_json({"b": 1, "a": 2}) == '{"a":2,"b":1}' + + def test_nested_sort(self): + assert _canonical_json({"b": {"y": 1, "x": 2}}) == '{"b":{"x":2,"y":1}}' + + def test_handles_unserializable_via_str(self): + class _X: + def __repr__(self) -> str: + return "" + + assert _canonical_json({"obj": _X()}) == '{"obj":""}' + + +# --- _class_fqn ------------------------------------------------------------------ + + +class TestClassFqn: + def test_returns_module_dot_qualname(self): + # Use the module-level _FakeScenario so we get a clean, predictable qualname. + fqn = _class_fqn(_FakeScenario) + assert fqn.endswith("._FakeScenario") + assert "." in fqn # must be a dotted FQN, not a bare class name + + +# --- _resolve_scenario_fqn ------------------------------------------------------- + + +class TestResolveScenarioFqn: + def test_resolves_registered_scenario(self): + cls = _resolve_scenario_fqn("pyrit.scenario.scenarios.garak.encoding.Encoding") + assert cls.__name__ == "Encoding" + + def test_rejects_non_dotted_fqn(self): + with pytest.raises(GraphArtifactSecurityError, match="not a dotted path"): + _resolve_scenario_fqn("AdaptiveScenario") + + def test_rejects_unimportable_module(self): + with pytest.raises(GraphArtifactSecurityError): + _resolve_scenario_fqn("nonexistent.module.path.SomeClass") + + def test_rejects_missing_attribute(self): + with pytest.raises(GraphArtifactSecurityError): + _resolve_scenario_fqn("pyrit.scenario.core.scenario.NotAClass") + + def test_rejects_non_scenario_subclass(self): + with pytest.raises(GraphArtifactSecurityError, match="not a Scenario subclass"): + _resolve_scenario_fqn("pyrit.scenario.core.input_schema.RoleDescriptor") + + def test_rejects_unregistered_scenario_subclass(self): + """A real Scenario subclass that's not registry-discoverable must be rejected.""" + + # Define a private Scenario subclass at module load time — the registry + # only discovers scenarios in pyrit.scenario.scenarios.*, so this should + # be rejected even though it IS a Scenario subclass. + # We can't easily inject one without polluting the registry, so we use + # the abstract `Scenario` itself which is not registered. + with pytest.raises(GraphArtifactSecurityError, match="not in the registry whitelist"): + _resolve_scenario_fqn("pyrit.scenario.core.scenario.Scenario") + + +# --- DatasetConfiguration round-trip -------------------------------------------- + + +class TestDatasetConfigSerialize: + def test_serializes_dataset_names(self): + cfg = MagicMock() + cfg._dataset_names = ["xstest"] + cfg.max_dataset_size = 10 + cfg._seed_groups = None + out = _serialize_dataset_config(cfg) + assert out == {"dataset_names": ["xstest"], "max_dataset_size": 10, "explicit_seed_groups_count": 0} + + def test_serializes_none_dataset_names(self): + cfg = MagicMock() + cfg._dataset_names = None + cfg.max_dataset_size = None + cfg._seed_groups = None + out = _serialize_dataset_config(cfg) + assert out["dataset_names"] is None + assert out["max_dataset_size"] is None + + def test_records_explicit_seed_group_count(self): + cfg = MagicMock() + cfg._dataset_names = None + cfg.max_dataset_size = None + cfg._seed_groups = [MagicMock(), MagicMock(), MagicMock()] + out = _serialize_dataset_config(cfg) + assert out["explicit_seed_groups_count"] == 3 + + +class TestDatasetConfigDeserialize: + def test_round_trips_dataset_names(self): + cfg = _deserialize_dataset_config( + {"dataset_names": ["xstest"], "max_dataset_size": 10, "explicit_seed_groups_count": 0} + ) + assert cfg is not None + assert cfg._dataset_names == ["xstest"] + assert cfg.max_dataset_size == 10 + + def test_empty_payload_returns_none(self): + assert _deserialize_dataset_config({}) is None + + def test_explicit_seed_groups_raises(self): + with pytest.raises(GraphArtifactError, match="explicit seed_groups"): + _deserialize_dataset_config({"explicit_seed_groups_count": 2}) + + +# --- _topology_hash -------------------------------------------------------------- + + +class TestTopologyHash: + def test_deterministic_across_calls(self): + summary = {"states": ["A", "B"], "initial_state": "A"} + assert _topology_hash(summary) == _topology_hash(summary) + + def test_changes_when_summary_changes(self): + a = _topology_hash({"states": ["A"]}) + b = _topology_hash({"states": ["A", "B"]}) + assert a != b + + +# --- build_topology_summary ------------------------------------------------------ + + +class TestBuildTopologySummary: + def test_collects_states_atoms_and_terminals(self): + scenario = _fake_initialized_scenario(atomic_names=("step_x",), states=("S0",), terminal_states=("DONE",)) + summary = build_topology_summary(scenario) + assert summary["states"] == ["S0"] + assert summary["initial_state"] == "STATE_0" + assert summary["terminal_states"] == ["DONE"] + assert len(summary["atomic_attacks"]) == 1 + assert summary["atomic_attacks"][0] == {"name": "step_x", "hash": "abc123"} + + def test_sorted_state_lists_for_determinism(self): + scenario = _fake_initialized_scenario(states=("S_B", "S_A", "S_C"), terminal_states=("T_B", "T_A")) + summary = build_topology_summary(scenario) + assert summary["states"] == ["S_A", "S_B", "S_C"] + assert summary["terminal_states"] == ["T_A", "T_B"] + + def test_includes_scenario_class_and_version(self): + scenario = _fake_initialized_scenario(version=7) + summary = build_topology_summary(scenario) + assert summary["scenario_version"] == 7 + assert "scenario_class_fqn" in summary + + +# --- _encode_init_inputs --------------------------------------------------------- + + +class TestEncodeInitInputs: + def test_scalar_passes_through(self): + schema = [RoleDescriptor(name="x", description="d", tag=RoleTag.SCALAR, param_type=str)] + out = _encode_init_inputs(schema=schema, init_inputs={"x": "hello"}) + assert out == {"x": "hello"} + + def test_opaque_value_snapshots_via_identifier(self): + instance = MagicMock() + instance.get_identifier.return_value.to_dict.return_value = {"cls": "Atomic", "hash": "h1"} + schema = [RoleDescriptor(name="atom", description="d", tag=RoleTag.OPAQUE)] + out = _encode_init_inputs(schema=schema, init_inputs={"atom": instance}) + assert out == {"atom": {"cls": "Atomic", "hash": "h1"}} + + def test_opaque_value_without_identifier_passes_through(self): + instance = object() # no get_identifier attribute + schema = [RoleDescriptor(name="atom", description="d", tag=RoleTag.OPAQUE)] + out = _encode_init_inputs(schema=schema, init_inputs={"atom": instance}) + assert out["atom"] is instance + + def test_opaque_none_passes_through(self): + schema = [RoleDescriptor(name="atom", description="d", tag=RoleTag.OPAQUE, required=False, default="x")] + out = _encode_init_inputs(schema=schema, init_inputs={"atom": None}) + assert out["atom"] is None + + def test_unknown_input_passes_through(self): + out = _encode_init_inputs(schema=[], init_inputs={"unknown_kwarg": 5}) + assert out == {"unknown_kwarg": 5} + + +# --- build_graph_artifact -------------------------------------------------------- + + +class TestBuildGraphArtifact: + def test_populates_all_fields(self): + scenario = _fake_initialized_scenario( + version=3, + strategies=("EASY", "HARD"), + dataset_names=["xstest"], + max_dataset_size=42, + include_baseline=True, + params={"alpha": 0.5}, + memory_labels={"run": "abc"}, + ) + artifact = build_graph_artifact(scenario) + assert artifact.scenario_class_fqn.endswith("._FakeScenario") + assert artifact.scenario_version == 3 + assert artifact.scenario_strategies == ["EASY", "HARD"] + assert artifact.dataset_config["dataset_names"] == ["xstest"] + assert artifact.dataset_config["max_dataset_size"] == 42 + assert artifact.include_baseline is True + assert artifact.params == {"alpha": 0.5} + assert artifact.memory_labels == {"run": "abc"} + assert artifact.topology_hash != "" + assert artifact.state_enum_fqn is not None + assert artifact.state_enum_fqn.endswith("._FakeStrategyEnum") + + def test_init_async_inputs_default_to_params(self): + scenario = _fake_initialized_scenario(params={"alpha": 1.0}) + artifact = build_graph_artifact(scenario) + assert artifact.init_async_inputs == {"alpha": 1.0} + + def test_explicit_init_async_inputs_override_params(self): + scenario = _fake_initialized_scenario(params={"alpha": 1.0}) + artifact = build_graph_artifact(scenario, init_async_inputs={"max_concurrency": 4}) + assert artifact.init_async_inputs == {"max_concurrency": 4} + + def test_topology_hash_stable_across_builds(self): + s1 = _fake_initialized_scenario() + s2 = _fake_initialized_scenario() + assert build_graph_artifact(s1).topology_hash == build_graph_artifact(s2).topology_hash + + +# --- YAML round-trip ------------------------------------------------------------- + + +class TestYamlRoundTrip: + def test_to_and_from_yaml(self, tmp_path): + scenario = _fake_initialized_scenario(version=5) + artifact = build_graph_artifact(scenario) + path = tmp_path / "artifact.yaml" + graph_artifact_to_yaml(artifact, path) + loaded = graph_artifact_from_yaml(path) + # field-for-field equivalence except set/frozenset reshape (none in artifact). + assert loaded == artifact + + def test_byte_identical_for_equal_artifacts(self, tmp_path): + s1 = _fake_initialized_scenario(params={"a": 1, "b": 2}) + s2 = _fake_initialized_scenario(params={"b": 2, "a": 1}) + a1 = build_graph_artifact(s1) + a2 = build_graph_artifact(s2) + p1 = tmp_path / "a1.yaml" + p2 = tmp_path / "a2.yaml" + graph_artifact_to_yaml(a1, p1) + graph_artifact_to_yaml(a2, p2) + assert p1.read_bytes() == p2.read_bytes() + + def test_yaml_payload_is_sort_key_canonical(self, tmp_path): + scenario = _fake_initialized_scenario() + artifact = build_graph_artifact(scenario) + path = tmp_path / "a.yaml" + graph_artifact_to_yaml(artifact, path) + text = path.read_text(encoding="utf-8") + # First top-level key should be alphabetically smallest, i.e. "artifact_version". + first_key = text.splitlines()[0].split(":", 1)[0] + assert first_key == "artifact_version" + + def test_from_yaml_rejects_wrong_artifact_version(self, tmp_path): + path = tmp_path / "stale.yaml" + # Build a valid artifact then rewrite its artifact_version to a stale value. + scenario = _fake_initialized_scenario() + artifact = build_graph_artifact(scenario) + graph_artifact_to_yaml(artifact, path) + payload = yaml.safe_load(path.read_text(encoding="utf-8")) + payload["artifact_version"] = 99 + path.write_text(yaml.safe_dump(payload, sort_keys=True), encoding="utf-8") + with pytest.raises(GraphArtifactError, match="artifact_version"): + graph_artifact_from_yaml(path) + + def test_from_yaml_rejects_non_mapping(self, tmp_path): + path = tmp_path / "bad.yaml" + path.write_text("- not a mapping", encoding="utf-8") + with pytest.raises(GraphArtifactError, match="not a mapping"): + graph_artifact_from_yaml(path) + + +# --- materialize_opaque_inputs --------------------------------------------------- + + +class _DummyScenarioCls: + """Provides ``input_schema`` for ``materialize_opaque_inputs`` tests.""" + + _schema: list[RoleDescriptor] = [] + + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return cls._schema + + +class TestMaterializeOpaqueInputs: + def test_scalar_passes_through(self): + _DummyScenarioCls._schema = [ + RoleDescriptor(name="alpha", description="d", tag=RoleTag.SCALAR, param_type=float) + ] + out = materialize_opaque_inputs(_DummyScenarioCls, {"alpha": 0.5}) # type: ignore[arg-type] + assert out == {"alpha": 0.5} + + def test_opaque_live_instance_passes_through(self): + instance = MagicMock() + _DummyScenarioCls._schema = [RoleDescriptor(name="atom", description="d", tag=RoleTag.OPAQUE)] + out = materialize_opaque_inputs(_DummyScenarioCls, {"atom": instance}) # type: ignore[arg-type] + assert out["atom"] is instance + + def test_opaque_dict_with_materializer_rebuilds(self): + rebuilt = object() + _DummyScenarioCls._schema = [RoleDescriptor(name="atom", description="d", tag=RoleTag.OPAQUE)] + out = materialize_opaque_inputs( + _DummyScenarioCls, # type: ignore[arg-type] + {"atom": {"hash": "h"}}, + opaque_materializers={"atom": lambda payload: rebuilt}, + ) + assert out["atom"] is rebuilt + + def test_opaque_dict_without_materializer_raises(self): + _DummyScenarioCls._schema = [RoleDescriptor(name="atom", description="d", tag=RoleTag.OPAQUE)] + with pytest.raises(OpaqueInputUnresolvedError) as exc: + materialize_opaque_inputs(_DummyScenarioCls, {"atom": {"hash": "h"}}) # type: ignore[arg-type] + assert exc.value.role_name == "atom" + assert "opaque_materializers" in str(exc.value) + + +# --- GraphArtifact dataclass shape ---------------------------------------------- + + +class TestGraphArtifactDataclass: + def test_frozen(self): + artifact = GraphArtifact(scenario_class_fqn="x.Y", scenario_version=1, pyrit_version="0.0.0") + with pytest.raises(Exception): + artifact.scenario_version = 2 # type: ignore[misc] + + def test_default_field_values(self): + artifact = GraphArtifact(scenario_class_fqn="x.Y", scenario_version=1, pyrit_version="0.0.0") + assert artifact.init_inputs == {} + assert artifact.scenario_strategies == [] + assert artifact.include_baseline is False + assert artifact.topology_hash == "" + assert artifact.state_enum_fqn is None + assert artifact.artifact_version == 1 + + +# --- Error-type shape ------------------------------------------------------------ + + +class TestErrorTypes: + def test_security_error_is_artifact_error(self): + assert issubclass(GraphArtifactSecurityError, GraphArtifactError) + + def test_drift_error_is_artifact_error(self): + assert issubclass(GraphArtifactDriftError, GraphArtifactError) + + def test_opaque_unresolved_is_artifact_error(self): + assert issubclass(OpaqueInputUnresolvedError, GraphArtifactError) + + def test_opaque_unresolved_carries_role_and_payload(self): + payload = {"k": 1} + exc = OpaqueInputUnresolvedError("alpha", payload) + assert exc.role_name == "alpha" + assert exc.payload is payload From 1a992f8cdfb4fe6fd4dd412016c166bdae86c919 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 15:25:04 -0700 Subject: [PATCH 22/42] opt BroadSweepThenDeepDive out of the technique registry inspection path Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/airt/sweep_then_deep_dive.py | 18 +++++++++ .../airt/test_sweep_then_deep_dive.py | 40 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index ee463e5d6..6b58a04e1 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -57,6 +57,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import AttackResult from pyrit.scenario.core.atomic_attack import AtomicAttack + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.score import TrueFalseScorer logger = logging.getLogger(__name__) @@ -605,6 +606,23 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ return [self._sweep_atomic, *self._deep_dive_atomics] + def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: + """ + Return an empty factory map: this scenario does not use the technique registry. + + The sweep and deep-dive atomics are supplied directly via the constructor, + so factory lookup never happens during execution. Overriding the base + method (which lazily populates the global ``AttackTechniqueRegistry`` + singleton via ``register_scenario_techniques``) keeps introspection + (e.g. :func:`pyrit.scenario.core.waterfall.policy_to_spec`) side-effect-free + and makes the policy-parameterized intent explicit, mirroring the + ``BASELINE_ATTACK_POLICY = Forbidden`` declaration above. + + Returns: + dict[str, AttackTechniqueFactory]: Always empty. + """ + return {} + def _build_execution_graph( # ty: ignore[invalid-method-override] self, *, diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py index becd8ee9c..17437d1d7 100644 --- a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py @@ -516,6 +516,46 @@ def test_baseline_attack_policy_is_forbidden(self) -> None: # if a caller passes ``include_baseline=True``. assert BroadSweepThenDeepDive.BASELINE_ATTACK_POLICY is BaselineAttackPolicy.Forbidden + def test_get_attack_technique_factories_returns_empty_dict(self) -> None: + # The sweep and deep-dive atomics are supplied directly via the + # constructor — no registry lookup ever happens during execution. + # Overriding the base method (which would lazily populate the global + # AttackTechniqueRegistry singleton via register_scenario_techniques) + # keeps introspection by waterfall.policy_to_spec side-effect-free. + scenario, _, _ = self._build_scenario( + sweep_response_text="safe", + scorer_label_for={"safe": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a"], + ) + assert scenario._get_attack_technique_factories() == {} + + def test_get_attack_technique_factories_does_not_mutate_global_registry(self) -> None: + # Removing the BSTDDive override would silently re-introduce the + # ``register_scenario_techniques()`` side-effect on the global + # ``AttackTechniqueRegistry`` singleton. Pin the absence of mutation + # so future refactors that drop the override fail this test. + from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry + + AttackTechniqueRegistry.reset_instance() + try: + before = set(AttackTechniqueRegistry.get_registry_singleton().get_factories().keys()) + + scenario, _, _ = self._build_scenario( + sweep_response_text="safe", + scorer_label_for={"safe": _SAFE_LABEL}, + deep_dive_display_groups=["cat-a"], + ) + scenario._get_attack_technique_factories() + + after = set(AttackTechniqueRegistry.get_registry_singleton().get_factories().keys()) + assert after == before, ( + f"BroadSweepThenDeepDive._get_attack_technique_factories() mutated the global " + f"AttackTechniqueRegistry singleton (added {sorted(after - before)}). The override " + f"must return {{}} without invoking register_scenario_techniques()." + ) + finally: + AttackTechniqueRegistry.reset_instance() + async def test_get_atomic_attacks_returns_canonical_order(self) -> None: scenario, sweep_atomic, deep_dives = self._build_scenario( sweep_response_text="safe", From e0f011bde7834175f45e9d062db97825a6081b01 Mon Sep 17 00:00:00 2001 From: ValbuenaVC Date: Wed, 20 May 2026 15:33:19 -0700 Subject: [PATCH 23/42] surface AdversarialBenchmark's benchmarkable specs from the technique-factory catalog Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/benchmark/adversarial.py | 42 +++++++++++ tests/unit/scenario/test_adversarial.py | 75 +++++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index dfec12839..33d8d5d8f 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from pyrit.prompt_target import PromptTarget + from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueFactory from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.score import TrueFalseScorer @@ -294,3 +295,44 @@ def _get_benchmarkable_specs() -> list[AttackTechniqueSpec]: for spec in SCENARIO_TECHNIQUES if AttackTechniqueRegistry._accepts_adversarial(spec.attack_class) and spec.adversarial_chat is None ] + + def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: + """ + Return locally-built factories backed by the benchmarkable specs. + + AdversarialBenchmark sweeps user-supplied ``adversarial_models`` rather + than using a single baked-in chat; its specs deliberately have + ``adversarial_chat = None`` and the per-model + :class:`AttackAdversarialConfig` is injected at attack-create time via + :attr:`_adversarial_configs`. + + The base implementation would dispatch to the global + :class:`AttackTechniqueRegistry` singleton, which carries the versions + of these specs that ``register_scenario_techniques`` has rewritten with + a default adversarial chat — a stale view that does not reflect what + this scenario actually runs. Worse, that path also triggers global + registry mutation as a side effect of being inspected. + + Returning factories built directly from + :meth:`_get_benchmarkable_specs` ensures that: + + * :func:`pyrit.scenario.core.waterfall.policy_to_spec` surfaces the + benchmarkable (no-chat) specs that the wizard / configurator can + round-trip without absorbing the default adversarial chat. + * No global registry registration occurs as a side effect of the + inspection. + + The returned factories are not used to execute attacks (that path runs + through :meth:`_get_atomic_attacks_async` with its own local factory + construction); they exist purely as the registry-shaped catalog of + techniques this scenario uses. + + Returns: + dict[str, AttackTechniqueFactory]: ``spec.name -> factory`` for + every benchmarkable spec; each factory carries the benchmarkable + spec (``adversarial_chat is None``) via ``source_spec``. + """ + return { + spec.name: AttackTechniqueRegistry.build_factory_from_spec(spec) + for spec in AdversarialBenchmark._get_benchmarkable_specs() + } diff --git a/tests/unit/scenario/test_adversarial.py b/tests/unit/scenario/test_adversarial.py index 5914a40ba..bf36692a4 100644 --- a/tests/unit/scenario/test_adversarial.py +++ b/tests/unit/scenario/test_adversarial.py @@ -25,6 +25,7 @@ from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy from pyrit.scenario.core.dataset_configuration import DatasetConfiguration from pyrit.scenario.core.scenario_techniques import SCENARIO_TECHNIQUES +from pyrit.scenario.core.waterfall import policy_to_spec from pyrit.scenario.scenarios.benchmark.adversarial import AdversarialBenchmark from pyrit.score import TrueFalseScorer @@ -617,3 +618,77 @@ def test_per_model_breakdown_reflects_outcome_counts(self): # Display grouping must not lose results. assert sum(len(rs) for rs in groups.values()) == sum(len(rs) for rs in attack_results.values()) + + +# =========================================================================== +# Phase 8a waterfall integration tests +# =========================================================================== + + +@pytest.mark.usefixtures(*FIXTURES) +class TestBenchmarkAttackTechniqueFactories: + """Pin AdversarialBenchmark._get_attack_technique_factories semantics. + + AdversarialBenchmark sweeps user-provided ``adversarial_models`` and injects + the chat per-attack at create-time. The factories it advertises through + ``_get_attack_technique_factories`` (consumed by ``policy_to_spec``) must + therefore reflect the *benchmarkable* specs (``adversarial_chat is None``), + NOT the global registry's replaced versions where + ``register_scenario_techniques`` has baked in a default adversarial chat. + """ + + def test_factory_keys_are_exactly_benchmarkable_spec_names(self, single_adversarial_model): + scenario = _make_benchmark(single_adversarial_model) + factories = scenario._get_attack_technique_factories() + expected = {spec.name for spec in AdversarialBenchmark._get_benchmarkable_specs()} + assert set(factories.keys()) == expected + + def test_factory_source_specs_have_no_baked_adversarial_chat(self, single_adversarial_model): + scenario = _make_benchmark(single_adversarial_model) + factories = scenario._get_attack_technique_factories() + for name, factory in factories.items(): + assert factory.source_spec is not None, f"factory {name!r} has no source_spec" + assert factory.source_spec.adversarial_chat is None, ( + f"factory {name!r} carries a baked-in adversarial_chat; AdversarialBenchmark " + "must surface specs whose chat is supplied per-model at create-time" + ) + + def test_factory_source_specs_match_benchmarkable_specs(self, single_adversarial_model): + scenario = _make_benchmark(single_adversarial_model) + factories = scenario._get_attack_technique_factories() + benchmarkable_by_name = {spec.name: spec for spec in AdversarialBenchmark._get_benchmarkable_specs()} + for name, factory in factories.items(): + assert factory.source_spec == benchmarkable_by_name[name] + + def test_get_attack_technique_factories_does_not_mutate_global_registry(self, single_adversarial_model): + before = set(AttackTechniqueRegistry.get_registry_singleton().get_factories().keys()) + scenario = _make_benchmark(single_adversarial_model) + scenario._get_attack_technique_factories() + after = set(AttackTechniqueRegistry.get_registry_singleton().get_factories().keys()) + assert before == after, ( + "AdversarialBenchmark._get_attack_technique_factories must not register " + "anything into the global registry singleton" + ) + + async def test_policy_to_spec_returns_benchmarkable_specs(self, mock_objective_target, single_adversarial_model): + """policy_to_spec must surface the specs the scenario actually uses.""" + groups = {"harmbench": _make_seed_groups("harmbench")} + with ( + patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") as mock_scorer, + ): + mock_scorer.return_value = MagicMock(spec=TrueFalseScorer, get_identifier=lambda: _mock_id("scorer")) + scenario = AdversarialBenchmark(adversarial_models=single_adversarial_model) + all_strat = scenario._strategy_class("all") + await scenario.initialize_async(objective_target=mock_objective_target, scenario_strategies=[all_strat]) + + specs = policy_to_spec(scenario) + assert specs, "policy_to_spec returned empty list for an initialized AdversarialBenchmark" + for spec in specs: + assert spec.adversarial_chat is None, ( + f"policy_to_spec returned spec {spec.name!r} with a baked-in adversarial_chat; " + "the wizard would mis-reconstruct the scenario" + ) + returned_names = {spec.name for spec in specs} + expected_names = {spec.name for spec in AdversarialBenchmark._get_benchmarkable_specs()} + assert returned_names == expected_names From 5e28feea315c831c3e02675310d53c92fe0a9d25 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:18:10 -0700 Subject: [PATCH 24/42] validate weakness_label at BroadSweepThenDeepDive constructor Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/airt/sweep_then_deep_dive.py | 15 ++++++++++++ .../airt/test_sweep_then_deep_dive.py | 23 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index 6b58a04e1..8f072185f 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -492,10 +492,25 @@ def __init__( Raises: ValueError: If ``deep_dive_atomic_attacks`` is empty. + ValueError: If ``weakness_label`` is not declared as one of + ``outcome_scorer.outcomes``. """ if not deep_dive_atomic_attacks: raise ValueError("BroadSweepThenDeepDive requires at least one deep_dive_atomic_attack.") + # Fail fast: the inner ``CategoryAggregatingSweepStep`` performs the + # same check, but only inside ``_build_execution_graph`` (called from + # ``run_async``). A wizard-built scenario would otherwise serialize a + # graph artifact and only error at first execution, defeating the + # "validate at the surface that elicited the input" guarantee that + # ``input_schema`` advertises for the ``weakness_label`` / ``outcome_scorer`` + # pair. + if weakness_label not in outcome_scorer.outcomes: + raise ValueError( + f"weakness_label {weakness_label!r} is not declared as an outcome of the " + f"supplied OutcomeScorer (declared: {outcome_scorer.outcomes!r})." + ) + self._sweep_atomic = sweep_atomic_attack self._deep_dive_atomics: list[AtomicAttack] = list(deep_dive_atomic_attacks) self._outcome_scorer = outcome_scorer diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py index 17437d1d7..a0854d871 100644 --- a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py @@ -504,6 +504,29 @@ def test_constructor_rejects_empty_deep_dive_list(self) -> None: outcome_scorer=scorer, ) + def test_constructor_rejects_weakness_label_not_in_outcome_scorer(self) -> None: + # The inner ``CategoryAggregatingSweepStep.__init__`` already validates + # this, but it only runs when ``_build_execution_graph`` is invoked at + # ``run_async`` time. A wizard-built scenario would happily save a + # graph artifact and only blow up the first time the user actually + # ran it. Fail-fast at outer-constructor time so the wizard / API + # caller gets the error at the same point they supplied the input. + wrapped = MagicMock(spec=Scorer) + wrapped.get_identifier.return_value = _make_scorer_id("MockScorer") + scorer = OutcomeScorer( + wrapped_scorer=wrapped, + outcome_map={_SAFE_LABEL: lambda s: True}, # no weakness label + ) + sweep = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + deep = _make_atomic_mock(name="deep", display_group="cat-a", attack_results=[]) + with pytest.raises(ValueError, match=r"weakness_label .* not declared"): + BroadSweepThenDeepDive( + sweep_atomic_attack=cast("AtomicAttack", sweep), + deep_dive_atomic_attacks=[cast("AtomicAttack", deep)], + outcome_scorer=scorer, + weakness_label="never_emitted", + ) + def test_strategy_metadata(self) -> None: assert BroadSweepThenDeepDive.get_strategy_class() is BroadSweepThenDeepDiveStrategy assert BroadSweepThenDeepDive.get_default_strategy() is BroadSweepThenDeepDiveStrategy.DEFAULT From 66d496c3e7bc41b3b2d857561a7a6adfb720847c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:22:20 -0700 Subject: [PATCH 25/42] validate epsilon, pool_threshold, max_attempts at AdaptiveScenario constructor Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/adaptive/adaptive_scenario.py | 18 +++++++++++ .../scenarios/adaptive/test_text_adaptive.py | 32 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index cc8729153..19a604365 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -90,7 +90,25 @@ def __init__( context_extractor (ContextExtractor): Maps a ``SeedAttackGroup`` to a context key. Defaults to ``global_context``. scenario_result_id (str | None): ID of an existing ``ScenarioResult`` to resume. + + Raises: + ValueError: If ``epsilon`` is outside [0.0, 1.0], ``pool_threshold`` < 1, + or ``max_attempts_per_objective`` < 1. """ + # Validate scalar inputs eagerly. ``AdaptiveTechniqueSelector`` and + # ``AdaptiveStep`` perform the same checks, but only when constructed + # lazily inside ``_get_atomic_attacks_async`` (called from + # ``initialize_async``). Failing fast at __init__ matches the + # elicitation surface advertised by ``input_schema`` so wizard / + # programmatic callers get the error on the same line they supplied + # the input. + if not 0.0 <= epsilon <= 1.0: + raise ValueError(f"epsilon must be in [0.0, 1.0], got {epsilon}") + if pool_threshold < 1: + raise ValueError(f"pool_threshold must be >= 1, got {pool_threshold}") + if max_attempts_per_objective < 1: + raise ValueError(f"max_attempts_per_objective must be >= 1, got {max_attempts_per_objective}") + if not objective_scorer: objective_scorer = self._get_default_objective_scorer() self._objective_scorer: TrueFalseScorer = objective_scorer diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index 5d97b2817..6eab442df 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -147,6 +147,38 @@ def test_init_stores_adaptive_params(self, mock_get_scorer, mock_objective_score assert scenario._max_attempts_per_objective == 7 assert scenario._seed == 42 + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + @pytest.mark.parametrize("bad_epsilon", [-0.01, 1.01, 2.0, -1.0]) + def test_init_rejects_epsilon_out_of_range(self, mock_get_scorer, mock_objective_scorer, bad_epsilon): + # The inner ``AdaptiveTechniqueSelector`` already validates this, but + # only when ``_get_atomic_attacks_async`` is called from + # ``initialize_async`` — i.e. after the wizard / programmatic caller + # has already committed inputs. Fail fast at __init__ so the input + # is rejected at the elicitation surface declared by ``input_schema``. + mock_get_scorer.return_value = mock_objective_scorer + with pytest.raises(ValueError, match=r"epsilon must be in \[0.0, 1.0\]"): + TextAdaptive(epsilon=bad_epsilon) + + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + @pytest.mark.parametrize("bad_pool_threshold", [0, -1, -100]) + def test_init_rejects_pool_threshold_below_one(self, mock_get_scorer, mock_objective_scorer, bad_pool_threshold): + # Same fail-late pattern: ``AdaptiveTechniqueSelector.__init__`` validates, + # but lazily. Surface the rejection at the constructor. + mock_get_scorer.return_value = mock_objective_scorer + with pytest.raises(ValueError, match="pool_threshold must be >= 1"): + TextAdaptive(pool_threshold=bad_pool_threshold) + + @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") + @pytest.mark.parametrize("bad_max_attempts", [0, -1, -10]) + def test_init_rejects_max_attempts_below_one(self, mock_get_scorer, mock_objective_scorer, bad_max_attempts): + # ``AdaptiveStep.__init__`` validates this, but each step is built only + # inside ``_build_step_for_seed_group`` (called by + # ``_get_atomic_attacks_async``). Pull the check up to the outer + # scenario constructor so the wizard surface fails fast. + mock_get_scorer.return_value = mock_objective_scorer + with pytest.raises(ValueError, match="max_attempts_per_objective must be >= 1"): + TextAdaptive(max_attempts_per_objective=bad_max_attempts) + @pytest.mark.usefixtures(*FIXTURES) class TestTextAdaptiveAtomicAttacks: From b5a518dd33f3d2f18ae657706f6561487d685798 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:30:15 -0700 Subject: [PATCH 26/42] pin TextAdaptive policy_to_spec/spec_to_enum round-trip Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/scenario/test_waterfall.py | 87 +++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/unit/scenario/test_waterfall.py b/tests/unit/scenario/test_waterfall.py index 425826660..3d03e1aae 100644 --- a/tests/unit/scenario/test_waterfall.py +++ b/tests/unit/scenario/test_waterfall.py @@ -216,3 +216,90 @@ def test_round_trip_policy_to_spec_to_enum(name: str): specs = policy_to_spec(scenario) enums = spec_to_enum(_DummyScenario, specs) assert enums == [_DummyStrategy[name.upper()]] + + +# ---------- first-party scenario round-trips --------------------------------- +# +# The DummyScenario coverage above pins behavior against a hand-crafted strategy +# class. These tests exercise the same surface against the real first-party +# AdaptiveScenario subclass (TextAdaptive), which inherits the base +# ``_get_attack_technique_factories`` implementation and therefore relies on +# the global ``AttackTechniqueRegistry`` populated by +# ``register_scenario_techniques()``. Regressions in registry-population +# (factories built without ``source_spec``) would break the wizard's ability to +# reconstruct an Adaptive policy from CLI inputs. + + +@pytest.mark.usefixtures("patch_central_database") +class TestTextAdaptiveRoundTrip: + """Pin that registry-backed first-party scenarios round-trip cleanly.""" + + @pytest.fixture(autouse=True) + def _reset_registry(self): + from pyrit.registry import TargetRegistry + from pyrit.scenario.scenarios.adaptive.text_adaptive import TextAdaptive + + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + TextAdaptive._cached_strategy_class = None + yield + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + TextAdaptive._cached_strategy_class = None + + def _build_scenario(self, monkeypatch: pytest.MonkeyPatch): + from pyrit.prompt_target import PromptTarget + from pyrit.scenario.scenarios.adaptive.text_adaptive import TextAdaptive + + adversarial = MagicMock(spec=PromptTarget) + adversarial.get_identifier.return_value = ComponentIdentifier( + class_name="MockAdversarial", class_module="tests.unit.scenario" + ) + monkeypatch.setattr( + "pyrit.scenario.core.scenario_techniques.get_default_adversarial_target", + lambda: adversarial, + ) + monkeypatch.setattr( + "pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer", + lambda self: _make_scorer_mock(), + ) + return TextAdaptive(), TextAdaptive + + def test_default_strategies_round_trip(self, monkeypatch: pytest.MonkeyPatch): + scenario, scenario_cls = self._build_scenario(monkeypatch) + strategy_cls = scenario_cls.get_strategy_class() + default = scenario_cls.get_default_strategy() + resolved = strategy_cls.resolve(None, default=default) + scenario._scenario_strategies = resolved + + specs = policy_to_spec(scenario) + assert [sp.name for sp in specs] == [m.value for m in resolved] + + enums = spec_to_enum(scenario_cls, specs) + assert enums == resolved + + def test_explicit_leaf_subset_round_trips(self, monkeypatch: pytest.MonkeyPatch): + scenario, scenario_cls = self._build_scenario(monkeypatch) + strategy_cls = scenario_cls.get_strategy_class() + # Use the first three concrete (non-aggregate) members. The exact set + # is whatever ``SCENARIO_TECHNIQUES`` registers; we don't hardcode it. + leaves = strategy_cls.get_all_strategies()[:3] + assert len(leaves) >= 1, "TextAdaptive should register at least one technique" + scenario._scenario_strategies = leaves + + specs = policy_to_spec(scenario) + assert [sp.name for sp in specs] == [m.value for m in leaves] + + enums = spec_to_enum(scenario_cls, specs) + assert enums == leaves + + def test_every_registered_factory_carries_source_spec(self, monkeypatch: pytest.MonkeyPatch): + scenario, _ = self._build_scenario(monkeypatch) + factories = scenario._get_attack_technique_factories() + assert factories, "registry should populate factories for TextAdaptive" + missing = [name for name, fac in factories.items() if fac.source_spec is None] + assert missing == [], ( + "Every registry-built factory must expose source_spec so " + "policy_to_spec can reconstruct the technique catalog " + f"(found {len(missing)} without source_spec: {missing})" + ) From bb928bce9c74e6c01d22d2f4e9a023248b395373 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:39:33 -0700 Subject: [PATCH 27/42] validate init_async_inputs keys against initialize_async signature Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/__init__.py | 2 + pyrit/scenario/core/__init__.py | 2 + pyrit/scenario/core/builder.py | 64 ++++++++++++- tests/unit/scenario/core/test_builder.py | 112 +++++++++++++++++++++++ 4 files changed, 176 insertions(+), 4 deletions(-) diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index f5e9c24c0..62996e74b 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -60,6 +60,7 @@ materialize_opaque_inputs, policy_to_spec, spec_to_enum, + validate_init_async_inputs, validate_init_inputs, ) @@ -136,5 +137,6 @@ "materialize_opaque_inputs", "policy_to_spec", "spec_to_enum", + "validate_init_async_inputs", "validate_init_inputs", ] diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 8f6406826..22fbe5b4d 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -12,6 +12,7 @@ build_scenario_from_inputs, discover_input_schema, discover_supported_parameters, + validate_init_async_inputs, validate_init_inputs, ) from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration @@ -100,5 +101,6 @@ "policy_to_spec", "register_scenario_techniques", "spec_to_enum", + "validate_init_async_inputs", "validate_init_inputs", ] diff --git a/pyrit/scenario/core/builder.py b/pyrit/scenario/core/builder.py index 28fa960ba..318d29994 100644 --- a/pyrit/scenario/core/builder.py +++ b/pyrit/scenario/core/builder.py @@ -18,6 +18,7 @@ from __future__ import annotations +import inspect from typing import TYPE_CHECKING, Any from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag @@ -129,6 +130,52 @@ def validate_init_inputs( ) +def validate_init_async_inputs( + *, + scenario_cls: type[Scenario], + init_async_inputs: dict[str, Any], +) -> None: + """ + Validate ``init_async_inputs`` against the scenario's ``initialize_async`` signature. + + Catches unknown keyword arguments at the wizard layer before they surface as a + raw ``TypeError`` from Python's call machinery (the ``@apply_defaults`` wrapper + around ``initialize_async`` calls ``inspect.Signature.bind`` and lets a + ``TypeError`` propagate for unknown kwargs). The wizard's retry loop in + :func:`pyrit.scenario.core.input_collector.collect_inputs_with_retry` catches + only :class:`ScenarioInputValidationError`, so an unwrapped ``TypeError`` would + crash the wizard rather than re-prompt. + + Type coercion is deliberately not performed here — the collector layer is + responsible for upstream coercion of scalars. Scenarios whose + ``initialize_async`` accepts ``**kwargs`` opt out of unknown-key validation. + + Args: + scenario_cls: Concrete subclass of :class:`Scenario`. + init_async_inputs: Caller-supplied ``initialize_async`` arguments. + + Raises: + ScenarioInputValidationError: If ``init_async_inputs`` contains a keyword + argument that ``initialize_async`` does not accept. + """ + sig = inspect.signature(scenario_cls.initialize_async) + accepts_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + if accepts_var_kw: + return + + accepted = { + name + for name, param in sig.parameters.items() + if name != "self" and param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + } + unknown = sorted(set(init_async_inputs) - accepted) + if unknown: + raise ScenarioInputValidationError( + f"Unknown init_async_inputs keys for {scenario_cls.__name__}: {unknown!r}. " + f"Accepted keys: {sorted(accepted)!r}." + ) + + async def build_scenario_from_inputs( scenario_cls: type[Scenario], *, @@ -145,24 +192,33 @@ async def build_scenario_from_inputs( rather than a deep ``TypeError`` from the constructor. ``init_async_inputs`` are passed verbatim to ``initialize_async``. The - existing :meth:`Scenario.set_params_from_args` machinery already validates - these against ``supported_parameters``, so the builder does not re-validate. + existing :meth:`Scenario.set_params_from_args` machinery validates + scenario-declared parameters; here we additionally pre-check that every key + in ``init_async_inputs`` is accepted by ``initialize_async``'s signature, so + typo'd keys surface as :class:`ScenarioInputValidationError` (recoverable by + the wizard's retry loop) rather than a raw ``TypeError`` (which would crash + the loop). Scenarios whose ``initialize_async`` accepts ``**kwargs`` opt out + of the unknown-key check. Args: scenario_cls: Concrete subclass of :class:`Scenario`. init_inputs: Rich-object ``__init__`` arguments keyed by role name. All required roles from ``input_schema()`` must be present. - init_async_inputs: Scalar ``initialize_async`` arguments. Validated by + init_async_inputs: Scalar ``initialize_async`` arguments. Validated + against the ``initialize_async`` signature for unknown keys, and by ``Scenario.set_params_from_args`` via ``supported_parameters()``. Returns: Scenario: An initialized, runnable scenario instance. Raises: - ScenarioInputValidationError: If ``init_inputs`` fails validation. + ScenarioInputValidationError: If ``init_inputs`` fails validation or + ``init_async_inputs`` contains keys not accepted by + ``initialize_async``. """ schema = discover_input_schema(scenario_cls) validate_init_inputs(schema=schema, init_inputs=init_inputs) + validate_init_async_inputs(scenario_cls=scenario_cls, init_async_inputs=init_async_inputs) scenario = scenario_cls(**init_inputs) await scenario.initialize_async(**init_async_inputs) diff --git a/tests/unit/scenario/core/test_builder.py b/tests/unit/scenario/core/test_builder.py index 302e1948e..53d4c1d71 100644 --- a/tests/unit/scenario/core/test_builder.py +++ b/tests/unit/scenario/core/test_builder.py @@ -15,6 +15,7 @@ build_scenario_from_inputs, discover_input_schema, discover_supported_parameters, + validate_init_async_inputs, validate_init_inputs, ) from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag @@ -180,6 +181,53 @@ def test_empty_schema_accepts_any_inputs(self): validate_init_inputs(schema=[], init_inputs={"whatever": 1}) +class TestValidateInitAsyncInputs: + def test_accepts_all_known_keys(self): + validate_init_async_inputs( + scenario_cls=cast("Any", _FakeScenarioScalarRoles), + init_async_inputs={"max_concurrency": 4}, + ) + + def test_accepts_empty(self): + validate_init_async_inputs( + scenario_cls=cast("Any", _FakeScenarioScalarRoles), + init_async_inputs={}, + ) + + def test_rejects_unknown_key(self): + with pytest.raises(ScenarioInputValidationError) as exc_info: + validate_init_async_inputs( + scenario_cls=cast("Any", _FakeScenarioScalarRoles), + init_async_inputs={"typo": 1}, + ) + message = str(exc_info.value) + assert "typo" in message + # The error should also list what *is* accepted to help the user recover. + assert "max_concurrency" in message + + def test_rejects_multiple_unknown_keys(self): + with pytest.raises(ScenarioInputValidationError) as exc_info: + validate_init_async_inputs( + scenario_cls=cast("Any", _FakeScenarioScalarRoles), + init_async_inputs={"foo": 1, "bar": 2}, + ) + message = str(exc_info.value) + assert "foo" in message and "bar" in message + + def test_var_keyword_opts_out(self): + """Scenarios whose ``initialize_async`` accepts ``**kwargs`` skip the check.""" + + class _VarKw(_FakeScenarioBase): + async def initialize_async(self, **kwargs: Any) -> None: + pass + + # Should not raise even with an arbitrary key. + validate_init_async_inputs( + scenario_cls=cast("Any", _VarKw), + init_async_inputs={"anything": "goes"}, + ) + + class TestBuildScenarioFromInputs: async def test_constructs_and_initializes(self): scenario = await build_scenario_from_inputs( @@ -235,6 +283,70 @@ async def test_choice_validation_fires(self): ) assert exc_info.value.role_name == "mode" + async def test_init_async_inputs_unknown_key_raises_validation_error(self): + """Lead 2: an unknown init_async_inputs key surfaces as ScenarioInputValidationError. + + Without pre-validation, ``initialize_async(**init_async_inputs)`` blows up with + a raw ``TypeError`` from Python's call machinery (via the @apply_defaults + wrapper's ``sig.bind``). The wizard's retry loop catches only + ``ScenarioInputValidationError``, so a typo'd flag would crash the wizard + instead of surfacing as a recoverable validation error. + """ + with pytest.raises(ScenarioInputValidationError) as exc_info: + await build_scenario_from_inputs( + cast("Any", _FakeScenarioScalarRoles), + init_inputs={"weakness_label": "harm"}, + init_async_inputs={"max_concurrency": 4, "bogus_typo": "value"}, + ) + assert "bogus_typo" in str(exc_info.value) + + async def test_init_async_inputs_multiple_unknown_keys_all_listed(self): + with pytest.raises(ScenarioInputValidationError) as exc_info: + await build_scenario_from_inputs( + cast("Any", _FakeScenarioScalarRoles), + init_inputs={"weakness_label": "harm"}, + init_async_inputs={"alpha": 1, "beta": 2}, + ) + message = str(exc_info.value) + assert "alpha" in message and "beta" in message + + async def test_init_async_inputs_unknown_key_does_not_construct_scenario(self): + """Validation must run before ``__init__`` to avoid orphaned construction side effects.""" + + class _TracksConstruction(_FakeScenarioBase): + constructed = False + + def __init__(self) -> None: + type(self).constructed = True + + async def initialize_async(self) -> None: + pass + + with pytest.raises(ScenarioInputValidationError): + await build_scenario_from_inputs( + cast("Any", _TracksConstruction), + init_inputs={}, + init_async_inputs={"bogus": 1}, + ) + assert _TracksConstruction.constructed is False + + async def test_init_async_inputs_var_keyword_accepts_anything(self): + """A scenario whose ``initialize_async`` accepts ``**kwargs`` opts out of validation.""" + + class _VarKw(_FakeScenarioBase): + def __init__(self) -> None: + pass + + async def initialize_async(self, **kwargs: Any) -> None: + self.received_kwargs = kwargs + + scenario = await build_scenario_from_inputs( + cast("Any", _VarKw), + init_inputs={}, + init_async_inputs={"whatever": "ok"}, + ) + assert scenario.received_kwargs == {"whatever": "ok"} # type: ignore[attr-defined] + class TestScenarioInputValidationError: def test_is_value_error_subclass(self): From 5a4f46399f2544b971789e52bf0e975090f89548 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:42:47 -0700 Subject: [PATCH 28/42] pin InputCollector runtime_checkable + ArtifactInputCollector defensive copy Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenario/core/test_input_collector.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/unit/scenario/core/test_input_collector.py b/tests/unit/scenario/core/test_input_collector.py index c03fa433c..4521869bc 100644 --- a/tests/unit/scenario/core/test_input_collector.py +++ b/tests/unit/scenario/core/test_input_collector.py @@ -242,6 +242,42 @@ def test_opaque_payload_passes_through(self): collector = ArtifactInputCollector({"instance": opaque_payload}) assert collector.collect(role=_opaque_role()) == opaque_payload + def test_defensive_copy_of_input_mapping(self): + """Mutating the source dict after construction does not affect the collector.""" + source: dict[str, Any] = {"label": "first"} + collector = ArtifactInputCollector(source) + source["label"] = "second" + assert collector.collect(role=_scalar_role()) == "first" + + def test_implements_input_collector_protocol(self): + assert isinstance(ArtifactInputCollector({}), InputCollector) + + +class TestInputCollectorProtocolNegative: + """Pin that ``@runtime_checkable`` does not over-match objects without ``collect``.""" + + def test_str_is_not_collector(self): + assert not isinstance("not a collector", InputCollector) + + def test_dict_is_not_collector(self): + assert not isinstance({"label": "v"}, InputCollector) + + def test_object_with_unrelated_attrs_is_not_collector(self): + class _NotACollector: + def some_other_method(self) -> None: + pass + + assert not isinstance(_NotACollector(), InputCollector) + + def test_object_with_collect_attr_is_collector(self): + """``@runtime_checkable`` only checks attribute presence, not signature.""" + + class _DuckTyped: + def collect(self, *, role: Any, error: Any = None, attempt: int = 0) -> Any: + return "ok" + + assert isinstance(_DuckTyped(), InputCollector) + # --- collect_inputs_with_retry --------------------------------------------------- From b2939198106c6461be0cbfcdbc8f35acbc6cf7c6 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 16:44:15 -0700 Subject: [PATCH 29/42] pin collect_inputs_with_retry only catches ScenarioInputValidationError Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenario/core/test_input_collector.py | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/unit/scenario/core/test_input_collector.py b/tests/unit/scenario/core/test_input_collector.py index 4521869bc..199d2b66e 100644 --- a/tests/unit/scenario/core/test_input_collector.py +++ b/tests/unit/scenario/core/test_input_collector.py @@ -353,3 +353,77 @@ def test_message_points_at_artifact_path(self): def test_is_not_implemented_error(self): """Allows callers to use the broader ``NotImplementedError`` catch.""" assert issubclass(OpaqueRoleNotElicitableError, NotImplementedError) + + +class TestCollectInputsWithRetryPropagation: + """Pin that non-``ScenarioInputValidationError`` exceptions propagate immediately. + + The retry loop must only catch :class:`ScenarioInputValidationError`; every other + exception type (including :class:`OpaqueRoleNotElicitableError`, + :class:`KeyboardInterrupt`, or any unrelated ``Exception``) must escape on the + first attempt so the wizard's outer handler can present the correct guidance. + Without this guarantee, an over-broad ``except`` in the retry loop would burn + the entire ``max_attempts`` budget on an unrecoverable error before raising + ``MaxAttemptsExceededError`` — masking the original cause from the wizard. + """ + + def test_opaque_role_not_elicitable_error_propagates_immediately(self): + class _OpaqueRaiser: + def __init__(self) -> None: + self.call_count = 0 + + def collect(self, *, role: RoleDescriptor, error: Exception | None = None, attempt: int = 0) -> Any: + self.call_count += 1 + raise OpaqueRoleNotElicitableError(role.name) + + collector = _OpaqueRaiser() + with pytest.raises(OpaqueRoleNotElicitableError) as exc_info: + collect_inputs_with_retry(collector=collector, schema=[_scalar_role(name="atomic")], max_attempts=5) + assert exc_info.value.role_name == "atomic" + assert collector.call_count == 1, "Opaque error must NOT burn retry budget" + + def test_unrelated_exception_propagates_immediately(self): + """A generic ``RuntimeError`` from a collector escapes without retry.""" + + class _BoomCollector: + def __init__(self) -> None: + self.call_count = 0 + + def collect(self, *, role: RoleDescriptor, error: Exception | None = None, attempt: int = 0) -> Any: + self.call_count += 1 + raise RuntimeError("kaboom") + + collector = _BoomCollector() + with pytest.raises(RuntimeError, match="kaboom"): + collect_inputs_with_retry(collector=collector, schema=[_scalar_role()], max_attempts=5) + assert collector.call_count == 1 + + def test_keyboard_interrupt_propagates_immediately(self): + """``KeyboardInterrupt`` (BaseException) must not be swallowed.""" + + class _CtrlC: + def __init__(self) -> None: + self.call_count = 0 + + def collect(self, *, role: RoleDescriptor, error: Exception | None = None, attempt: int = 0) -> Any: + self.call_count += 1 + raise KeyboardInterrupt + + collector = _CtrlC() + with pytest.raises(KeyboardInterrupt): + collect_inputs_with_retry(collector=collector, schema=[_scalar_role()], max_attempts=5) + assert collector.call_count == 1 + + def test_opaque_error_on_later_role_does_not_keep_prior_attempts(self): + """Opaque error mid-schema raises cleanly, no MaxAttemptsExceededError swallow.""" + + class _MixedCollector: + def collect(self, *, role: RoleDescriptor, error: Exception | None = None, attempt: int = 0) -> Any: + if role.name == "first": + return "ok" + raise OpaqueRoleNotElicitableError(role.name) + + schema = [_scalar_role(name="first"), _opaque_role(name="second")] + with pytest.raises(OpaqueRoleNotElicitableError) as exc_info: + collect_inputs_with_retry(collector=_MixedCollector(), schema=schema, max_attempts=5) + assert exc_info.value.role_name == "second" From 24211e32025384eb606b2f4deddf1262524fe7ed Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 17:26:45 -0700 Subject: [PATCH 30/42] FEAT: Phase 8d pyrit_wizard CLI --- pyproject.toml | 1 + pyrit/cli/pyrit_wizard.py | 757 ++++++++++++++++++++++++++++ tests/unit/cli/test_pyrit_wizard.py | 555 ++++++++++++++++++++ 3 files changed, 1313 insertions(+) create mode 100644 pyrit/cli/pyrit_wizard.py create mode 100644 tests/unit/cli/test_pyrit_wizard.py diff --git a/pyproject.toml b/pyproject.toml index bdc563d00..16be7f8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -160,6 +160,7 @@ all = [ pyrit_backend = "pyrit.cli.pyrit_backend:main" pyrit_scan = "pyrit.cli.pyrit_scan:main" pyrit_shell = "pyrit.cli.pyrit_shell:main" +pyrit_wizard = "pyrit.cli.pyrit_wizard:main" [tool.pytest.ini_options] addopts = [ diff --git a/pyrit/cli/pyrit_wizard.py b/pyrit/cli/pyrit_wizard.py new file mode 100644 index 000000000..4ce83bf29 --- /dev/null +++ b/pyrit/cli/pyrit_wizard.py @@ -0,0 +1,757 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +PyRIT CLI - Interactive wizard for assembling a scenario from declared inputs. + +The wizard composes the Phase 8a/8b/8c/8g surfaces: + +* :func:`pyrit.scenario.core.builder.discover_input_schema` — + rich-object ``__init__`` roles declared by the scenario class. +* :func:`pyrit.scenario.core.builder.discover_supported_parameters` — + scalar ``initialize_async`` parameters (legacy contract). +* :class:`pyrit.scenario.core.input_collector.CliInputCollector` — + stdin/stdout elicitation with the error-recovery retry loop. +* :func:`pyrit.scenario.core.builder.build_scenario_from_inputs` — + the single entry point that constructs + initializes the scenario. +* :func:`pyrit.scenario.core.graph_artifact.build_graph_artifact` / + :func:`pyrit.scenario.core.graph_artifact.graph_artifact_to_yaml` — + optional ``--save artifact.yaml`` capture. + +Use ``pyrit_scan --from-artifact path.yaml`` (Phase 8e) to replay an artifact. + +Examples:: + + # List available scenarios and exit + pyrit_wizard --list-scenarios + + # Pick a scenario interactively, prompting for every input + pyrit_wizard --target my_target --initializers target + + # Skip the interactive picker, save an artifact, but don't run yet + pyrit_wizard foundry.red_team_agent --target my_target --save my_run.yaml \\ + --initializers target + + # Build + run end-to-end with one elicitation pass + pyrit_wizard garak.encoding --target my_target --run --initializers target +""" + +from __future__ import annotations + +import asyncio +import logging +import sys +from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +from pyrit.cli._cli_args import ( + ARG_HELP, + _parse_initializer_arg, + non_negative_int, + positive_int, + validate_log_level_argparse, +) + +if TYPE_CHECKING: + from pyrit.cli.frontend_core import FrontendCore + from pyrit.common.parameter import Parameter + from pyrit.scenario.core import Scenario + from pyrit.scenario.core.input_schema import RoleDescriptor + + +_DESCRIPTION = """PyRIT Wizard - Interactively assemble and (optionally) run a scenario. + +Walks the scenario's declared input_schema() (rich __init__ roles) and +supported_parameters() (scalar initialize_async args) via an interactive +prompt, then builds the scenario, optionally saves a reproducible artifact, +and optionally runs it. + +Examples: + # List available scenarios + pyrit_wizard --list-scenarios + + # Wizard for a specific scenario, save artifact, don't run + pyrit_wizard foundry.red_team_agent --target my_target --save out.yaml \\ + --initializers target + + # Build + run end-to-end + pyrit_wizard garak.encoding --target my_target --run --initializers target + +Note: Scenarios whose input_schema() declares OPAQUE roles (e.g. pre-built +AtomicAttack or OutcomeScorer instances) cannot be elicited end-to-end from a +CLI. The wizard exits with a helpful pointer to the programmatic API or +``pyrit_scan --from-artifact path.yaml``. +""" + + +def _build_parser() -> ArgumentParser: + """ + Build the ``pyrit_wizard`` argparse parser. + + Returns: + ArgumentParser: The fully configured argument parser. + """ + parser = ArgumentParser( + prog="pyrit_wizard", + description=_DESCRIPTION, + formatter_class=RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "scenario_name", + type=str, + nargs="?", + help="Name of the scenario to wizard. If omitted, you'll pick from a menu.", + ) + + parser.add_argument( + "--config-file", + type=Path, + help=ARG_HELP["config_file"], + ) + + parser.add_argument( + "--log-level", + type=validate_log_level_argparse, + default=logging.WARNING, + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", + ) + + parser.add_argument( + "--list-scenarios", + action="store_true", + help="List all available scenarios and exit", + ) + + parser.add_argument( + "--list-targets", + action="store_true", + help="List registered targets and exit. Requires initializers that register targets " + "(e.g., --initializers target).", + ) + + parser.add_argument( + "--initializers", + type=_parse_initializer_arg, + nargs="+", + help=ARG_HELP["initializers"], + ) + + parser.add_argument( + "--initialization-scripts", + type=str, + nargs="+", + help=ARG_HELP["initialization_scripts"], + ) + + parser.add_argument( + "--target", + type=str, + help="Name of a registered target. If omitted, the wizard prompts interactively.", + ) + + parser.add_argument( + "--save", + type=Path, + help="Write the built scenario as a graph artifact (YAML) to this path.", + ) + + parser.add_argument( + "--run", + action="store_true", + help="Run the built scenario after elicitation completes.", + ) + + parser.add_argument( + "--max-concurrency", + type=positive_int, + help=ARG_HELP["max_concurrency"], + ) + + parser.add_argument( + "--max-retries", + type=non_negative_int, + help=ARG_HELP["max_retries"], + ) + + parser.add_argument( + "--max-attempts-per-role", + type=positive_int, + default=5, + help="Maximum re-prompt attempts for any single role before bailing (default: 5).", + ) + + parser.add_argument( + "--non-interactive", + action="store_true", + help="Disable interactive prompts; fail fast on any missing input. " + "Useful for scripted runs that supply every input via --config-file.", + ) + + return parser + + +def parse_args(args: Optional[list[str]] = None) -> Namespace: + """ + Parse command-line arguments for ``pyrit_wizard``. + + Args: + args: Optional argv list (mainly for tests). ``None`` means use ``sys.argv``. + + Returns: + Namespace: The parsed argument namespace. + """ + return _build_parser().parse_args(args) + + +def _role_from_parameter(parameter: Parameter) -> RoleDescriptor: + """ + Adapt a :class:`Parameter` (scalar ``initialize_async`` declaration) to a + :class:`RoleDescriptor` so the same :class:`CliInputCollector` retry loop + can elicit both halves of the lifecycle. + + Choice-typed parameters become :attr:`RoleTag.CHOICE`; everything else is + treated as a :attr:`RoleTag.SCALAR`. Optionality follows the parameter's + default: a non-None default makes the role optional. + + The collector treats a role as required iff ``required=True``, so we set + ``required=False`` whenever the parameter declares a default (the + scenario's ``set_params_from_args`` will materialize the default if the + user skips the prompt). + + Args: + parameter: The :class:`Parameter` declared by the scenario. + + Returns: + RoleDescriptor: The adapted descriptor consumed by :class:`CliInputCollector`. + """ + from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + + tag = RoleTag.CHOICE if parameter.choices is not None else RoleTag.SCALAR + required = parameter.default is None and parameter.choices is None + return RoleDescriptor( + name=parameter.name, + description=parameter.description, + tag=tag, + param_type=parameter.param_type, + choices=parameter.choices, + default=parameter.default, + required=required, + ) + + +def _pick_from_menu(*, prompt: str, options: list[str]) -> str: + """ + Show a numbered menu of ``options`` and return the chosen item. + + Accepts either a 1-based index or an exact item name. Re-prompts on + invalid input until the user provides a recognized value or EOFs. + + Args: + prompt: Prompt text shown above the option list. + options: Available choices to display, in order. + + Returns: + str: The chosen option. + + Raises: + ValueError: When ``options`` is empty. + """ + if not options: + raise ValueError("Cannot pick from an empty option list.") + + while True: + print(prompt) + for index, option in enumerate(options, start=1): + print(f" {index}. {option}") + raw = input("> ").strip() + if raw in options: + return raw + try: + picked = int(raw) + except ValueError: + print(f" ! Invalid input {raw!r}; enter a number 1-{len(options)} or an item name.") + continue + if 1 <= picked <= len(options): + return options[picked - 1] + print(f" ! Index {picked} out of range; enter a number 1-{len(options)}.") + + +def _resolve_scenario_class( + *, + context: FrontendCore, + explicit_name: str | None, + interactive: bool, +) -> tuple[type[Scenario], str]: + """ + Resolve the scenario class to wizard. + + Returns the class and the registry name it was found under so callers can + print user-friendly progress messages. When ``explicit_name`` is provided + we look it up directly; otherwise (when ``interactive=True``) we offer a + numbered menu of registered scenarios. + + Args: + context: The initialized :class:`FrontendCore` (registry source). + explicit_name: Scenario name supplied via positional arg, if any. + interactive: When True, the wizard may prompt for missing inputs. + + Returns: + tuple[type[Scenario], str]: Resolved scenario class and the registry + name it was looked up under. + + Raises: + RuntimeError: When the scenario registry is empty. + ValueError: When ``explicit_name`` is unknown, or when no name was + supplied and ``interactive`` is False. + """ + registry = context.scenario_registry + names = sorted(registry.get_names()) + if not names: + raise RuntimeError("No scenarios are registered. Did you forget --initializers or --initialization-scripts?") + + if explicit_name: + try: + scenario_class = registry.get_class(explicit_name) + except KeyError as exc: + raise ValueError(f"Scenario {explicit_name!r} not found. Available: {', '.join(names)}") from exc + return scenario_class, explicit_name + + if not interactive: + raise ValueError( + "No scenario_name provided and --non-interactive was set. " + "Either pass a positional scenario name or drop --non-interactive." + ) + + chosen = _pick_from_menu(prompt="Available scenarios:", options=names) + return registry.get_class(chosen), chosen + + +def _resolve_target_name( + *, + explicit_name: str | None, + available_targets: list[str], + interactive: bool, +) -> str: + """ + Resolve the objective_target name, prompting interactively when needed. + + Args: + explicit_name: Target name supplied via ``--target``, if any. + available_targets: Registered target names from the populated + :class:`TargetRegistry`. + interactive: When True, the wizard may prompt the user. + + Returns: + str: The resolved target name. + + Raises: + RuntimeError: When no targets are registered. + ValueError: When ``explicit_name`` is unknown, or when no target was + supplied and ``interactive`` is False. + """ + if explicit_name is not None: + if explicit_name not in available_targets: + raise ValueError( + f"Target {explicit_name!r} not found. Available: {', '.join(available_targets) or '(none)'}" + ) + return explicit_name + + if not available_targets: + raise RuntimeError( + "No targets are registered. Did you forget --initializers target " + "or --initialization-scripts pointing at a script that registers a target?" + ) + + if not interactive: + raise ValueError("No --target provided and --non-interactive was set. Pass --target NAME explicitly.") + + return _pick_from_menu(prompt="Available targets:", options=available_targets) + + +def _resolve_strategies( + *, + scenario_class: type[Scenario], + interactive: bool, +) -> list[Any] | None: + """ + Optionally elicit scenario strategies from the user. + + Returns ``None`` when the user skips the prompt (the scenario then uses + its declared default aggregate). Returns a non-empty list of enum members + when the user supplied at least one selection. + + Args: + scenario_class: The scenario whose strategy enum should be elicited. + interactive: When False, the wizard skips strategy elicitation entirely. + + Returns: + list[Any] | None: Selected enum members, or ``None`` if the user skipped + the prompt or the scenario has no declared strategy enum. + + Raises: + ValueError: When a supplied strategy name or index is invalid. + """ + if not interactive: + return None + + try: + strategy_class = scenario_class.get_strategy_class() + except Exception: + return None + + members = [member.value for member in strategy_class] + if not members: + return None + + print( + f"\nScenario strategies for {scenario_class.__name__} (comma-separated names or numbers, blank = use defaults):" + ) + for index, member in enumerate(members, start=1): + print(f" {index}. {member}") + raw = input("> ").strip() + if not raw: + return None + + tokens = [token.strip() for token in raw.split(",") if token.strip()] + selected: list[Any] = [] + for token in tokens: + if token in members: + selected.append(strategy_class(token)) + continue + try: + index = int(token) + except ValueError as exc: + raise ValueError(f"Strategy {token!r} not recognized. Available: {', '.join(members)}") from exc + if not 1 <= index <= len(members): + raise ValueError(f"Strategy index {index} out of range (1-{len(members)}).") + selected.append(strategy_class(members[index - 1])) + + return selected or None + + +def _collect_init_inputs( + *, + scenario_class: type[Scenario], + interactive: bool, + max_attempts: int, +) -> dict[str, Any]: + """ + Walk the scenario's ``input_schema()`` and collect rich-object init inputs. + + OPAQUE roles surface a clear error pointing at the artifact-replay + workflow — the wizard cannot elicit pre-built ``Identifiable`` instances + from stdin alone. + + Args: + scenario_class: The concrete scenario being constructed. + interactive: When False, declared schema items raise instead of prompting. + max_attempts: Per-role retry budget before bailing with + :class:`MaxAttemptsExceededError`. + + Returns: + dict[str, Any]: Collected values keyed by role name. + + Raises: + RuntimeError: When the schema includes an OPAQUE role the CLI cannot elicit. + ValueError: When the schema is non-empty but ``interactive`` is False. + """ + from pyrit.scenario.core.builder import discover_input_schema + from pyrit.scenario.core.input_collector import ( + CliInputCollector, + OpaqueRoleNotElicitableError, + collect_inputs_with_retry, + ) + + schema = discover_input_schema(scenario_class) + if not schema: + return {} + + if not interactive: + raise ValueError( + f"{scenario_class.__name__}.input_schema() declares {len(schema)} role(s) " + "but --non-interactive was set. Use the programmatic API " + "(build_scenario_from_inputs) or --config-file to supply them." + ) + + print(f"\n-- Constructor inputs for {scenario_class.__name__} --") + collector = CliInputCollector() + try: + return collect_inputs_with_retry( + collector=collector, + schema=schema, + max_attempts=max_attempts, + ) + except OpaqueRoleNotElicitableError as exc: + raise RuntimeError( + f"{scenario_class.__name__} declares opaque constructor inputs that cannot be " + f"elicited from the CLI: {exc}\n" + "Build it programmatically via `build_scenario_from_inputs(...)`, or use a " + "previously saved artifact via `pyrit_scan --from-artifact path.yaml`." + ) from exc + + +def _collect_init_async_inputs( + *, + scenario_class: type[Scenario], + interactive: bool, + max_attempts: int, + cli_overrides: dict[str, Any], +) -> dict[str, Any]: + """ + Walk the scenario's ``supported_parameters()`` and collect init-async inputs. + + Adapts each :class:`Parameter` to a :class:`RoleDescriptor` and runs the + same retry loop as constructor inputs. CLI overrides (``--max-concurrency``, + ``--max-retries``) win over any elicited value. + + Args: + scenario_class: The concrete scenario being initialized. + interactive: When False, the wizard skips schema-driven prompting and + relies on ``cli_overrides`` plus the scenario's own defaults. + max_attempts: Per-role retry budget before bailing. + cli_overrides: Values supplied on the command line. ``None`` entries + are ignored; non-``None`` entries always win over prompts. + + Returns: + dict[str, Any]: Collected ``initialize_async`` kwargs. + """ + from pyrit.scenario.core.builder import discover_supported_parameters + from pyrit.scenario.core.input_collector import ( + CliInputCollector, + collect_inputs_with_retry, + ) + + parameters = discover_supported_parameters(scenario_class) + if not parameters: + collected: dict[str, Any] = {} + elif not interactive: + # In non-interactive mode the user is expected to supply everything via + # CLI overrides (or accept declared defaults via set_params_from_args). + collected = {} + else: + print(f"\n-- Runtime parameters for {scenario_class.__name__} --") + schema = [_role_from_parameter(parameter) for parameter in parameters] + collected = collect_inputs_with_retry( + collector=CliInputCollector(), + schema=schema, + max_attempts=max_attempts, + ) + + # CLI overrides land on top so flags always win over prompts. + collected.update({key: value for key, value in cli_overrides.items() if value is not None}) + return collected + + +def _print_scenarios_list(*, context: FrontendCore) -> int: + """ + Print a one-line summary of every registered scenario and exit. + + Args: + context: The initialized :class:`FrontendCore`. + + Returns: + int: Process exit code (always ``0``). + """ + registry = context.scenario_registry + metadata_list = sorted(registry.list_metadata(), key=lambda m: m.registry_name) + if not metadata_list: + print("No scenarios are registered.") + return 0 + print(f"Registered scenarios ({len(metadata_list)}):") + for metadata in metadata_list: + description = metadata.class_description.splitlines()[0] if metadata.class_description else "" + print(f" {metadata.registry_name}{f' — {description}' if description else ''}") + return 0 + + +def _print_targets_list(*, target_names: list[str]) -> int: + """ + Print the list of registered targets and exit. + + Args: + target_names: Names returned by :func:`frontend_core.list_targets_async`. + + Returns: + int: Process exit code (always ``0``). + """ + if not target_names: + print( + "No targets are registered. Pass --initializers target (or another " + "initializer that registers targets) and try again." + ) + return 0 + print(f"Registered targets ({len(target_names)}):") + for name in target_names: + print(f" {name}") + return 0 + + +async def _run_wizard_async(*, parsed_args: Namespace) -> int: + """ + Drive the wizard end-to-end. + + Returns an integer exit code suitable for ``sys.exit``. + + Args: + parsed_args: The parsed argparse namespace. + + Returns: + int: ``0`` on success, ``1`` on any handled error surfaced to stdout. + + Raises: + RuntimeError: When the target registry drops the selected target between + elicitation and build (defensive — indicates a fixture race). + """ + # Deferred imports so ``--help`` stays instant. + from pyrit.cli import frontend_core + from pyrit.scenario.core.builder import build_scenario_from_inputs + from pyrit.scenario.core.graph_artifact import ( + build_graph_artifact, + graph_artifact_to_yaml, + ) + + interactive = not parsed_args.non_interactive + + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as exc: + print(f"Error: {exc}") + return 1 + + context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, + initialization_scripts=initialization_scripts, + initializer_names=parsed_args.initializers, + log_level=parsed_args.log_level, + ) + await context.initialize_async() + + if parsed_args.list_scenarios: + return _print_scenarios_list(context=context) + + # Populate the target registry via the same initializer flow pyrit_scan uses. + target_names = await frontend_core.list_targets_async(context=context) + + if parsed_args.list_targets: + return _print_targets_list(target_names=target_names) + + scenario_class, scenario_name = _resolve_scenario_class( + context=context, + explicit_name=parsed_args.scenario_name, + interactive=interactive, + ) + print(f"\nWizard: {scenario_name} ({scenario_class.__name__})") + + target_name = _resolve_target_name( + explicit_name=parsed_args.target, + available_targets=target_names, + interactive=interactive, + ) + + init_inputs = _collect_init_inputs( + scenario_class=scenario_class, + interactive=interactive, + max_attempts=parsed_args.max_attempts_per_role, + ) + + cli_overrides: dict[str, Any] = { + "max_concurrency": parsed_args.max_concurrency, + "max_retries": parsed_args.max_retries, + } + init_async_inputs = _collect_init_async_inputs( + scenario_class=scenario_class, + interactive=interactive, + max_attempts=parsed_args.max_attempts_per_role, + cli_overrides=cli_overrides, + ) + + # objective_target / scenario_strategies are injected after collection so + # they always reflect explicit CLI or interactive picker choices, not a + # value the user may have typed at a generic "objective_target" prompt. + from pyrit.registry import TargetRegistry + + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(target_name) + if objective_target is None: + # Defensive: list_targets_async populated the registry, so absence here + # means a race with another test fixture clearing it. Surface a clear + # message rather than letting build raise a TypeError downstream. + raise RuntimeError( + f"Target {target_name!r} disappeared from registry between elicitation and build. Re-run the wizard." + ) + init_async_inputs["objective_target"] = objective_target + + strategies = _resolve_strategies( + scenario_class=scenario_class, + interactive=interactive, + ) + if strategies is not None: + init_async_inputs["scenario_strategies"] = strategies + + print(f"\nBuilding {scenario_class.__name__}...") + sys.stdout.flush() + scenario = await build_scenario_from_inputs( + scenario_class, + init_inputs=init_inputs, + init_async_inputs=init_async_inputs, + ) + print(" ok.") + + if parsed_args.save is not None: + artifact = build_graph_artifact( + scenario, + init_inputs=init_inputs, + init_async_inputs=init_async_inputs, + ) + graph_artifact_to_yaml(artifact, parsed_args.save) + print(f" artifact saved -> {parsed_args.save}") + + if parsed_args.run: + print(f"\nRunning {scenario_class.__name__}...") + sys.stdout.flush() + result = await scenario.run_async() + from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter + + printer = PrettyScenarioResultMemoryPrinter() + await printer.print_summary_async(result) + else: + print( + "\nDone. Pass --run to execute, or rerun with --save PATH and replay via `pyrit_scan --from-artifact PATH`." + ) + + return 0 + + +def main(args: Optional[list[str]] = None) -> int: + """ + Start the PyRIT wizard CLI. + + Returns: + int: Exit code (0 on success, 1 on error). + """ + try: + parsed_args = parse_args(args) + except SystemExit as exc: + return exc.code if isinstance(exc.code, int) else 1 + + print("Starting PyRIT wizard...") + sys.stdout.flush() + + try: + return asyncio.run(_run_wizard_async(parsed_args=parsed_args)) + except KeyboardInterrupt: + print("\nAborted by user.") + return 130 + except Exception as exc: + print(f"\nError: {exc}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/unit/cli/test_pyrit_wizard.py b/tests/unit/cli/test_pyrit_wizard.py new file mode 100644 index 000000000..56b962e36 --- /dev/null +++ b/tests/unit/cli/test_pyrit_wizard.py @@ -0,0 +1,555 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Unit tests for the ``pyrit_wizard`` CLI module.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.cli import pyrit_wizard +from pyrit.common.parameter import Parameter +from pyrit.scenario.core.input_collector import OpaqueRoleNotElicitableError +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag + +# --- parse_args ---------------------------------------------------------------- + + +def test_parse_args_defaults_for_bare_scenario_name(): + args = pyrit_wizard.parse_args(["my_scenario"]) + + assert args.scenario_name == "my_scenario" + assert args.list_scenarios is False + assert args.list_targets is False + assert args.target is None + assert args.save is None + assert args.run is False + assert args.log_level == logging.WARNING + assert args.max_attempts_per_role == 5 + assert args.non_interactive is False + + +def test_parse_args_list_scenarios_flag(): + args = pyrit_wizard.parse_args(["--list-scenarios"]) + + assert args.list_scenarios is True + assert args.scenario_name is None + + +def test_parse_args_full_command_line(): + args = pyrit_wizard.parse_args( + [ + "my_scenario", + "--target", + "my_target", + "--save", + "out.yaml", + "--run", + "--initializers", + "target", + "--max-concurrency", + "4", + "--max-retries", + "2", + "--max-attempts-per-role", + "3", + "--non-interactive", + ] + ) + + assert args.scenario_name == "my_scenario" + assert args.target == "my_target" + assert args.save == Path("out.yaml") + assert args.run is True + assert args.initializers == ["target"] + assert args.max_concurrency == 4 + assert args.max_retries == 2 + assert args.max_attempts_per_role == 3 + assert args.non_interactive is True + + +# --- _role_from_parameter ------------------------------------------------------- + + +def test_role_from_parameter_scalar_required_when_no_default(): + parameter = Parameter(name="threshold", description="cutoff", param_type=int) + + role = pyrit_wizard._role_from_parameter(parameter) + + assert role.name == "threshold" + assert role.tag is RoleTag.SCALAR + assert role.param_type is int + assert role.required is True + assert role.default is None + + +def test_role_from_parameter_optional_when_default_present(): + parameter = Parameter(name="batch_size", description="batch", default=10, param_type=int) + + role = pyrit_wizard._role_from_parameter(parameter) + + assert role.required is False + assert role.default == 10 + + +def test_role_from_parameter_choice_when_choices_declared(): + parameter = Parameter( + name="mode", + description="execution mode", + default="fast", + param_type=str, + choices=("fast", "thorough"), + ) + + role = pyrit_wizard._role_from_parameter(parameter) + + assert role.tag is RoleTag.CHOICE + assert role.choices == ("fast", "thorough") + assert role.required is False + + +# --- _pick_from_menu ------------------------------------------------------------ + + +def test_pick_from_menu_accepts_index(): + with patch("builtins.input", return_value="2"): + choice = pyrit_wizard._pick_from_menu(prompt="pick:", options=["a", "b", "c"]) + + assert choice == "b" + + +def test_pick_from_menu_accepts_exact_name(): + with patch("builtins.input", return_value="c"): + choice = pyrit_wizard._pick_from_menu(prompt="pick:", options=["a", "b", "c"]) + + assert choice == "c" + + +def test_pick_from_menu_reprompts_on_invalid_then_succeeds(capsys): + inputs = iter(["foo", "9", "1"]) + with patch("builtins.input", side_effect=lambda _: next(inputs)): + choice = pyrit_wizard._pick_from_menu(prompt="pick:", options=["alpha", "beta"]) + + assert choice == "alpha" + captured = capsys.readouterr() + assert "Invalid input" in captured.out + assert "out of range" in captured.out + + +def test_pick_from_menu_raises_on_empty_options(): + with pytest.raises(ValueError, match="empty"): + pyrit_wizard._pick_from_menu(prompt="pick:", options=[]) + + +# --- _resolve_scenario_class ---------------------------------------------------- + + +def _make_context_with_scenarios(scenarios: dict[str, type]) -> MagicMock: + registry = MagicMock() + registry.get_names.return_value = list(scenarios) + registry.get_class.side_effect = lambda name: scenarios[name] + context = MagicMock() + context.scenario_registry = registry + return context + + +def test_resolve_scenario_class_uses_explicit_name(): + fake_cls = type("FakeScenario", (), {}) + context = _make_context_with_scenarios({"my_scenario": fake_cls}) + + cls, name = pyrit_wizard._resolve_scenario_class(context=context, explicit_name="my_scenario", interactive=True) + + assert cls is fake_cls + assert name == "my_scenario" + + +def test_resolve_scenario_class_raises_when_explicit_missing(): + fake_cls = type("Foo", (), {}) + context = _make_context_with_scenarios({"foo": fake_cls}) + context.scenario_registry.get_class.side_effect = KeyError("bar") + + with pytest.raises(ValueError, match="not found"): + pyrit_wizard._resolve_scenario_class(context=context, explicit_name="bar", interactive=True) + + +def test_resolve_scenario_class_interactive_pick(): + cls_a = type("A", (), {}) + cls_b = type("B", (), {}) + context = _make_context_with_scenarios({"alpha": cls_a, "beta": cls_b}) + + with patch("builtins.input", return_value="beta"): + cls, name = pyrit_wizard._resolve_scenario_class(context=context, explicit_name=None, interactive=True) + + assert cls is cls_b + assert name == "beta" + + +def test_resolve_scenario_class_non_interactive_no_name_raises(): + context = _make_context_with_scenarios({"foo": type("Foo", (), {})}) + + with pytest.raises(ValueError, match="non-interactive"): + pyrit_wizard._resolve_scenario_class(context=context, explicit_name=None, interactive=False) + + +def test_resolve_scenario_class_raises_when_registry_empty(): + context = _make_context_with_scenarios({}) + + with pytest.raises(RuntimeError, match="No scenarios are registered"): + pyrit_wizard._resolve_scenario_class(context=context, explicit_name=None, interactive=True) + + +# --- _resolve_target_name ------------------------------------------------------- + + +def test_resolve_target_name_uses_explicit_name(): + name = pyrit_wizard._resolve_target_name(explicit_name="t1", available_targets=["t1", "t2"], interactive=True) + + assert name == "t1" + + +def test_resolve_target_name_rejects_unknown_explicit_name(): + with pytest.raises(ValueError, match="not found"): + pyrit_wizard._resolve_target_name(explicit_name="t9", available_targets=["t1", "t2"], interactive=True) + + +def test_resolve_target_name_interactive_picker(): + with patch("builtins.input", return_value="2"): + name = pyrit_wizard._resolve_target_name(explicit_name=None, available_targets=["one", "two"], interactive=True) + + assert name == "two" + + +def test_resolve_target_name_non_interactive_without_target_raises(): + with pytest.raises(ValueError, match="--non-interactive"): + pyrit_wizard._resolve_target_name(explicit_name=None, available_targets=["t1"], interactive=False) + + +def test_resolve_target_name_empty_registry_raises(): + with pytest.raises(RuntimeError, match="No targets are registered"): + pyrit_wizard._resolve_target_name(explicit_name=None, available_targets=[], interactive=True) + + +# --- _resolve_strategies -------------------------------------------------------- + + +class _FakeStrategy: + """Minimal stand-in for a ScenarioStrategy enum member.""" + + def __init__(self, value: str) -> None: + self.value = value + self.name = value.upper() + + def __eq__(self, other: object) -> bool: + return isinstance(other, _FakeStrategy) and other.value == self.value + + def __hash__(self) -> int: + return hash(self.value) + + +class _FakeStrategyClass: + """Iterable+callable stand-in for a ScenarioStrategy enum class.""" + + def __init__(self, values: list[str]) -> None: + self._members = [_FakeStrategy(v) for v in values] + + def __iter__(self): + return iter(self._members) + + def __call__(self, value: str) -> _FakeStrategy: + for member in self._members: + if member.value == value: + return member + raise ValueError(value) + + +def _make_strategy_class(values: list[str]) -> _FakeStrategyClass: + return _FakeStrategyClass(values) + + +def test_resolve_strategies_returns_none_on_blank_input(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + scenario_class.get_strategy_class.return_value = _make_strategy_class(["a", "b"]) + + with patch("builtins.input", return_value=""): + result = pyrit_wizard._resolve_strategies(scenario_class=scenario_class, interactive=True) + + assert result is None + + +def test_resolve_strategies_returns_none_in_non_interactive_mode(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + + result = pyrit_wizard._resolve_strategies(scenario_class=scenario_class, interactive=False) + + assert result is None + scenario_class.get_strategy_class.assert_not_called() + + +def test_resolve_strategies_picks_by_index_and_name(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + scenario_class.get_strategy_class.return_value = _make_strategy_class(["alpha", "beta", "gamma"]) + + with patch("builtins.input", return_value="1,gamma"): + result = pyrit_wizard._resolve_strategies(scenario_class=scenario_class, interactive=True) + + assert result is not None + assert [member.value for member in result] == ["alpha", "gamma"] + + +def test_resolve_strategies_raises_on_unknown_name(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + scenario_class.get_strategy_class.return_value = _make_strategy_class(["alpha"]) + + with patch("builtins.input", return_value="bogus"): + with pytest.raises(ValueError, match="not recognized"): + pyrit_wizard._resolve_strategies(scenario_class=scenario_class, interactive=True) + + +def test_resolve_strategies_raises_on_out_of_range_index(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + scenario_class.get_strategy_class.return_value = _make_strategy_class(["alpha"]) + + with patch("builtins.input", return_value="9"): + with pytest.raises(ValueError, match="out of range"): + pyrit_wizard._resolve_strategies(scenario_class=scenario_class, interactive=True) + + +# --- _collect_init_inputs ------------------------------------------------------- + + +def test_collect_init_inputs_returns_empty_for_default_schema(): + scenario_class = MagicMock() + scenario_class.input_schema.return_value = [] + + result = pyrit_wizard._collect_init_inputs(scenario_class=scenario_class, interactive=True, max_attempts=5) + + assert result == {} + + +def test_collect_init_inputs_surfaces_opaque_role_with_pointer_to_artifact(): + scenario_class = MagicMock() + scenario_class.__name__ = "OpaqueScenario" + opaque_role = RoleDescriptor( + name="big_obj", + description="pre-built thing", + tag=RoleTag.OPAQUE, + required=True, + ) + scenario_class.input_schema.return_value = [opaque_role] + + fake_collector = MagicMock() + fake_collector.collect.side_effect = OpaqueRoleNotElicitableError("big_obj") + with patch( + "pyrit.scenario.core.input_collector.CliInputCollector", + return_value=fake_collector, + ): + with pytest.raises(RuntimeError, match="--from-artifact"): + pyrit_wizard._collect_init_inputs( + scenario_class=scenario_class, + interactive=True, + max_attempts=5, + ) + + +def test_collect_init_inputs_non_interactive_with_schema_raises(): + scenario_class = MagicMock() + scenario_class.__name__ = "ConfiguredScenario" + scenario_class.input_schema.return_value = [ + RoleDescriptor(name="x", description="x", tag=RoleTag.SCALAR, param_type=int) + ] + + with pytest.raises(ValueError, match="non-interactive"): + pyrit_wizard._collect_init_inputs(scenario_class=scenario_class, interactive=False, max_attempts=5) + + +def test_collect_init_inputs_returns_collected_dict(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + role = RoleDescriptor(name="alpha", description="a", tag=RoleTag.SCALAR, param_type=str) + scenario_class.input_schema.return_value = [role] + + fake_collector = MagicMock() + fake_collector.collect.return_value = "yes" + + with patch( + "pyrit.scenario.core.input_collector.CliInputCollector", + return_value=fake_collector, + ): + result = pyrit_wizard._collect_init_inputs(scenario_class=scenario_class, interactive=True, max_attempts=5) + + assert result == {"alpha": "yes"} + + +# --- _collect_init_async_inputs ------------------------------------------------- + + +def test_collect_init_async_inputs_no_params_returns_overrides_only(): + scenario_class = MagicMock() + scenario_class.supported_parameters.return_value = [] + + result = pyrit_wizard._collect_init_async_inputs( + scenario_class=scenario_class, + interactive=True, + max_attempts=5, + cli_overrides={"max_concurrency": 4, "max_retries": None}, + ) + + assert result == {"max_concurrency": 4} + + +def test_collect_init_async_inputs_overrides_win_over_prompts(): + scenario_class = MagicMock() + scenario_class.__name__ = "Demo" + scenario_class.supported_parameters.return_value = [ + Parameter(name="threshold", description="t", default=1, param_type=int) + ] + + fake_collector = MagicMock() + fake_collector.collect.return_value = 5 + + with patch( + "pyrit.scenario.core.input_collector.CliInputCollector", + return_value=fake_collector, + ): + result = pyrit_wizard._collect_init_async_inputs( + scenario_class=scenario_class, + interactive=True, + max_attempts=5, + cli_overrides={"threshold": 99, "max_concurrency": None}, + ) + + assert result["threshold"] == 99 + + +# --- _print_scenarios_list ------------------------------------------------------ + + +def test_print_scenarios_list_empty(capsys): + registry = MagicMock() + registry.list_metadata.return_value = [] + context = MagicMock() + context.scenario_registry = registry + + rc = pyrit_wizard._print_scenarios_list(context=context) + captured = capsys.readouterr() + + assert rc == 0 + assert "No scenarios are registered." in captured.out + + +def test_print_scenarios_list_renders_first_description_line(capsys): + registry = MagicMock() + item = MagicMock() + item.registry_name = "foo" + item.class_description = "First line.\nSecond line should be hidden." + registry.list_metadata.return_value = [item] + context = MagicMock() + context.scenario_registry = registry + + rc = pyrit_wizard._print_scenarios_list(context=context) + captured = capsys.readouterr() + + assert rc == 0 + assert "foo — First line." in captured.out + assert "Second line should be hidden." not in captured.out + + +# --- _print_targets_list -------------------------------------------------------- + + +def test_print_targets_list_empty(capsys): + rc = pyrit_wizard._print_targets_list(target_names=[]) + captured = capsys.readouterr() + + assert rc == 0 + assert "No targets are registered." in captured.out + + +def test_print_targets_list_lists_names(capsys): + rc = pyrit_wizard._print_targets_list(target_names=["t1", "t2"]) + captured = capsys.readouterr() + + assert rc == 0 + assert "Registered targets (2)" in captured.out + assert "t1" in captured.out + assert "t2" in captured.out + + +# --- main / _run_wizard_async --------------------------------------------------- + + +def test_main_handles_keyboard_interrupt(capsys): + with patch( + "pyrit.cli.pyrit_wizard._run_wizard_async", + new=AsyncMock(side_effect=KeyboardInterrupt()), + ): + rc = pyrit_wizard.main(["--list-scenarios"]) + captured = capsys.readouterr() + + assert rc == 130 + assert "Aborted by user." in captured.out + + +def test_main_returns_one_on_unexpected_exception(capsys): + with patch( + "pyrit.cli.pyrit_wizard._run_wizard_async", + new=AsyncMock(side_effect=RuntimeError("boom")), + ): + rc = pyrit_wizard.main(["--list-scenarios"]) + captured = capsys.readouterr() + + assert rc == 1 + assert "Error: boom" in captured.out + + +def test_main_returns_zero_on_success(): + with patch("pyrit.cli.pyrit_wizard._run_wizard_async", new=AsyncMock(return_value=0)): + rc = pyrit_wizard.main(["--list-scenarios"]) + + assert rc == 0 + + +async def test_run_wizard_async_list_scenarios_short_circuits(capsys): + """--list-scenarios returns before any target / scenario resolution runs.""" + with patch("pyrit.cli.frontend_core.FrontendCore") as fake_frontend_cls: + context = MagicMock() + context.initialize_async = AsyncMock() + registry = MagicMock() + item = MagicMock() + item.registry_name = "foo" + item.class_description = "Foo scenario" + registry.list_metadata.return_value = [item] + context.scenario_registry = registry + fake_frontend_cls.return_value = context + + parsed = pyrit_wizard.parse_args(["--list-scenarios"]) + rc = await pyrit_wizard._run_wizard_async(parsed_args=parsed) + + captured = capsys.readouterr() + assert rc == 0 + assert "foo — Foo scenario" in captured.out + # No target listing or scenario picker triggered. + assert "Available targets" not in captured.out + assert "Wizard:" not in captured.out + + +async def test_run_wizard_async_initialization_script_missing_returns_one(capsys): + with patch( + "pyrit.cli.frontend_core.resolve_initialization_scripts", + side_effect=FileNotFoundError("nope.py not found"), + ): + parsed = pyrit_wizard.parse_args(["--list-scenarios", "--initialization-scripts", "nope.py"]) + rc = await pyrit_wizard._run_wizard_async(parsed_args=parsed) + + captured = capsys.readouterr() + assert rc == 1 + assert "nope.py not found" in captured.out From 9010ff7e775e4b46195776f2815c230eae14b0ca Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 17:35:12 -0700 Subject: [PATCH 31/42] FEAT: Phase 8e pyrit_scan --from-artifact + inverse waterfall --- pyrit/cli/frontend_core.py | 105 +++++++++++++++++ pyrit/cli/pyrit_scan.py | 26 +++++ pyrit/scenario/__init__.py | 4 + pyrit/scenario/core/__init__.py | 4 +- pyrit/scenario/core/waterfall.py | 127 +++++++++++++++++++- tests/unit/cli/test_pyrit_scan.py | 68 +++++++++++ tests/unit/scenario/test_waterfall.py | 160 +++++++++++++++++++++++++- 7 files changed, 486 insertions(+), 8 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index f0a85b24d..e25bc5302 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -526,6 +526,111 @@ async def run_scenario_async( return result +async def run_scenario_from_artifact_async( + *, + artifact_path: Path, + context: FrontendCore, + target_name: str | None = None, + allow_drift: bool = False, + print_summary: bool = True, +) -> ScenarioResult: + """ + Load a graph artifact from disk and run the reconstructed scenario. + + Counterpart to :func:`run_scenario_async` for the Phase 8e replay flow. The + artifact already encodes the scenario class, role inputs, and lifecycle + config; the caller only needs to supply ``target_name`` (env-specific, never + captured in the artifact) and the initializers that register that target. + + Args: + artifact_path: Path to a YAML artifact produced by + :func:`pyrit.scenario.core.graph_artifact.graph_artifact_to_yaml`. + context: PyRIT context with loaded registries. + target_name: Name of a registered target. Required — the artifact does + not capture the target instance. + allow_drift: When ``True``, scenario_version / topology-hash mismatches + are tolerated. Defaults to ``False`` (strict-fail) mirroring the + resume contract. + print_summary: Whether to print the scenario summary after execution. + Defaults to True. + + Returns: + ScenarioResult: The result of the scenario execution. + + Raises: + ValueError: When ``target_name`` is missing or unknown. + GraphArtifactError: On security / drift / opaque-resolution failures. + """ + from pyrit.scenario.core.graph_artifact import ( + graph_artifact_from_yaml, + load_scenario_from_artifact, + ) + + if not context._initialized: + await context.initialize_async() + + initializer_instances = None + if context._initializer_configs: + print(f"Running {len(context._initializer_configs)} initializer(s)...") + sys.stdout.flush() + + initializer_instances = [] + for config in context._initializer_configs: + initializer_class = context.initializer_registry.get_class(config.name) + instance = initializer_class() + if config.args: + instance.set_params_from_args(args=config.args) + initializer_instances.append(instance) + + await initialize_pyrit_async( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances, + env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), + ) + + if not target_name: + raise ValueError( + "--target is required when running from an artifact (the artifact " + "does not capture the objective_target instance)." + ) + + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(target_name) + if objective_target is None: + available_names = target_registry.get_names() + if not available_names: + raise ValueError( + f"Target '{target_name}' not found. The target registry is empty.\n" + "Targets are registered by initializers. Make sure to include an initializer " + "that registers targets (e.g., --initializers target)." + ) + raise ValueError( + f"Target '{target_name}' not found in registry.\nAvailable targets: {', '.join(available_names)}" + ) + + print(f"\nLoading artifact: {artifact_path}") + sys.stdout.flush() + artifact = graph_artifact_from_yaml(artifact_path) + + print(f"Replaying scenario: {artifact.scenario_class_fqn}") + sys.stdout.flush() + scenario = await load_scenario_from_artifact( + artifact, + objective_target=objective_target, + allow_drift=allow_drift, + ) + + result = await scenario.run_async() + + if print_summary: + printer = ConsoleScenarioResultPrinter() + await printer.print_summary_async(result) + + return result + + def _format_wrapped_text(*, text: str, indent: str, width: int = 78) -> str: """ Format text with word wrapping. diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index d85c2235f..fd67f67ed 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -177,6 +177,20 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: help=ARG_HELP["target"], ) + parser.add_argument( + "--from-artifact", + type=Path, + dest="from_artifact", + help="Replay a scenario from a saved graph artifact (YAML produced by pyrit_wizard --save). " + "When set, the positional scenario_name and --strategies are ignored; --target is still required.", + ) + + parser.add_argument( + "--allow-drift", + action="store_true", + help="When loading via --from-artifact, tolerate scenario_version / topology-hash drift.", + ) + return parser @@ -447,6 +461,18 @@ def main(args: Optional[list[str]] = None) -> int: log_level=parsed_args.log_level, ) + # Artifact replay short-circuits the scenario_name / --strategies path. + if parsed_args.from_artifact is not None: + asyncio.run( + frontend_core.run_scenario_from_artifact_async( + artifact_path=parsed_args.from_artifact, + context=context, + target_name=parsed_args.target, + allow_drift=parsed_args.allow_drift, + ) + ) + return 0 + # Resolve the effective scenario name: CLI positional wins, config falls through. config_scenario = context._scenario_config effective_scenario_name = parsed_args.scenario_name or (config_scenario.name if config_scenario else None) diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 62996e74b..657787571 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -53,6 +53,7 @@ collect_inputs_with_retry, discover_input_schema, discover_supported_parameters, + enum_to_spec, graph_artifact_from_yaml, graph_artifact_to_yaml, linear_strategy_policy, @@ -60,6 +61,7 @@ materialize_opaque_inputs, policy_to_spec, spec_to_enum, + spec_to_policy_inputs, validate_init_async_inputs, validate_init_inputs, ) @@ -128,6 +130,7 @@ "collect_inputs_with_retry", "discover_input_schema", "discover_supported_parameters", + "enum_to_spec", "foundry", "garak", "graph_artifact_from_yaml", @@ -137,6 +140,7 @@ "materialize_opaque_inputs", "policy_to_spec", "spec_to_enum", + "spec_to_policy_inputs", "validate_init_async_inputs", "validate_init_inputs", ] diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 22fbe5b4d..eebea0494 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -49,7 +49,7 @@ register_scenario_techniques, ) from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy, linear_strategy_policy -from pyrit.scenario.core.waterfall import policy_to_spec, spec_to_enum +from pyrit.scenario.core.waterfall import enum_to_spec, policy_to_spec, spec_to_enum, spec_to_policy_inputs __all__ = [ "ArtifactInputCollector", @@ -91,6 +91,7 @@ "collect_inputs_with_retry", "discover_input_schema", "discover_supported_parameters", + "enum_to_spec", "get_default_adversarial_target", "get_default_scorer_target", "graph_artifact_from_yaml", @@ -101,6 +102,7 @@ "policy_to_spec", "register_scenario_techniques", "spec_to_enum", + "spec_to_policy_inputs", "validate_init_async_inputs", "validate_init_inputs", ] diff --git a/pyrit/scenario/core/waterfall.py b/pyrit/scenario/core/waterfall.py index 38da62b47..2400b60ad 100644 --- a/pyrit/scenario/core/waterfall.py +++ b/pyrit/scenario/core/waterfall.py @@ -4,7 +4,7 @@ """ Phase 8 waterfall — translation between scenario configuration layers. -Three layers, two forward translations (this module): +Three layers, four translations: * **policy** — a configured :class:`Scenario` with a ready :class:`StrategyGraph` (the executable artifact). @@ -13,7 +13,7 @@ * **strategy enum** — the public ``ScenarioStrategy`` enum members that represent those techniques (the CLI / wizard surface). -The forward direction is lossy but well-defined: +Forward direction (lossy but well-defined): * ``policy_to_spec`` extracts the spec layer. Returns ``[]`` when the scenario does not use the technique registry pattern (e.g. policy-parameterized @@ -21,14 +21,26 @@ * ``spec_to_enum`` resolves specs to ``ScenarioStrategy`` members of the scenario's strategy class. Returns ``None`` when no member matches. -The inverse direction (``enum_to_spec`` + ``spec_to_policy_inputs``) lands in -Phase 8e; both are partial and best-effort. Bugs in any direction stay -localized to this file. +Inverse direction (best-effort, partial; Phase 8e): + +* ``enum_to_spec`` looks up specs for selected enum members via the global + ``AttackTechniqueRegistry`` singleton (``register_scenario_techniques()`` + is called once to populate it). Skips members that are not registered + there — that's the documented "best-effort" semantics. +* ``spec_to_policy_inputs`` returns the rich-object ``__init__`` payload + needed to reconstruct a scenario from a spec list. Returns ``{}`` when + the scenario's ``input_schema()`` is empty or all required roles have + defaults; returns ``None`` for scenarios with required OPAQUE roles + (the inverse waterfall can't materialize closures / instances from + spec metadata — the caller is expected to fall back to + ``pyrit_scan --from-artifact``). + +Bugs in any direction stay localized to this file. """ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from pyrit.registry.object_registries.attack_technique_registry import ( @@ -115,3 +127,106 @@ def spec_to_enum( return None resolved.append(member) return resolved + + +def enum_to_spec(strategies: list[ScenarioStrategy]) -> list[AttackTechniqueSpec]: + """ + Look up :class:`AttackTechniqueSpec` instances for selected strategy enum members. + + Inverse of :func:`spec_to_enum`. Populates the global + :class:`AttackTechniqueRegistry` singleton via + :func:`register_scenario_techniques` (idempotent), then matches each + strategy's ``value`` against the registry's factory catalog. Each factory's + :attr:`AttackTechniqueFactory.source_spec` is the spec we return. + + Best-effort by design: strategies that are not registered in the global + catalog (e.g. members of a per-scenario benchmark catalog like + :class:`AdversarialBenchmark`) are silently skipped. Use + :func:`policy_to_spec` on an instantiated scenario for catalog-specific + spec extraction. + + Args: + strategies (list[ScenarioStrategy]): Selected enum members. + + Returns: + list[AttackTechniqueSpec]: One spec per strategy that resolves through + the global registry, in input order. May be shorter than the input list + when some strategies belong to a per-scenario catalog not in the global + registry. + """ + if not strategies: + return [] + + from pyrit.registry.object_registries.attack_technique_registry import ( + AttackTechniqueRegistry, + ) + from pyrit.scenario.core.scenario_techniques import register_scenario_techniques + + register_scenario_techniques() + factories = AttackTechniqueRegistry.get_registry_singleton().get_factories() + + specs: list[AttackTechniqueSpec] = [] + for strategy in strategies: + factory = factories.get(strategy.value) + if factory is None or factory.source_spec is None: + continue + specs.append(factory.source_spec) + return specs + + +def spec_to_policy_inputs( + scenario_cls: type[Scenario], + specs: list[AttackTechniqueSpec], +) -> dict[str, Any] | None: + """ + Derive the rich-object ``__init__`` payload from a spec list, when possible. + + Inverse-direction half of the waterfall. Designed to be partial: + + * Returns ``{}`` for scenarios whose :meth:`Scenario.input_schema` is empty + or only declares roles with defaults — most first-party scenarios fall + here. The spec list itself becomes ``scenario_strategies`` at + ``initialize_async`` time (via :func:`spec_to_enum`), not a constructor + argument. + * Returns ``None`` for scenarios with required OPAQUE roles + (e.g. :class:`BroadSweepThenDeepDive`'s pre-built ``AtomicAttack`` / + ``OutcomeScorer``). Spec metadata can't materialize a closure or a + bound :class:`Identifiable` instance — the caller is expected to fall + back to ``pyrit_scan --from-artifact path.yaml``. + + Args: + scenario_cls (type[Scenario]): The scenario class to derive inputs for. + specs (list[AttackTechniqueSpec]): The spec catalog the policy uses. + Currently not used to fill in scalar inputs — kept in the signature + so future scenarios that derive constructor kwargs from spec + metadata (e.g. shared seed-technique counts) can do so without a + signature break. + + Returns: + dict[str, Any] | None: Constructor kwargs (often ``{}``), or ``None`` + when the scenario declares required inputs that cannot be reconstructed + from spec metadata alone. + """ + del specs # placeholder: future scenarios may derive scalars from specs + + try: + schema = list(scenario_cls.input_schema()) + except (AttributeError, NotImplementedError): + return {} + + if not schema: + return {} + + from pyrit.scenario.core.input_schema import RoleTag + + for role in schema: + if role.required and role.tag is RoleTag.OPAQUE: + return None + + # Schemas with only optional or non-OPAQUE roles can fall through to + # constructor defaults; the wizard / artifact path supplies anything else. + for role in schema: + if role.required and role.default is None: + return None + + return {} diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index e3c3d42c4..e393de311 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -155,6 +155,25 @@ def test_parse_args_with_list_targets(self): assert args.list_targets is True + def test_parse_args_from_artifact_defaults_to_none(self): + args = pyrit_scan.parse_args(["test_scenario"]) + + assert args.from_artifact is None + assert args.allow_drift is False + + def test_parse_args_with_from_artifact(self): + args = pyrit_scan.parse_args(["--from-artifact", "/tmp/scan.yaml", "--target", "openai"]) + + assert args.from_artifact == Path("/tmp/scan.yaml") + assert args.target == "openai" + assert args.scenario_name is None + + def test_parse_args_with_from_artifact_and_allow_drift(self): + args = pyrit_scan.parse_args(["--from-artifact", "scan.yaml", "--target", "openai", "--allow-drift"]) + + assert args.from_artifact == Path("scan.yaml") + assert args.allow_drift is True + class TestMain: """Tests for main function.""" @@ -271,6 +290,55 @@ def test_main_no_scenario_specified(self, capsys): captured = capsys.readouterr() assert "No scenario specified" in captured.out + @patch("pyrit.cli.pyrit_scan.asyncio.run") + @patch("pyrit.cli.frontend_core.run_scenario_from_artifact_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_from_artifact_short_circuits_normal_path( + self, + mock_frontend_core: MagicMock, + mock_run_scenario: AsyncMock, + mock_run_from_artifact: AsyncMock, + mock_asyncio_run: MagicMock, + ): + """`--from-artifact` dispatches via the artifact helper, not the normal flow.""" + result = pyrit_scan.main( + ["--from-artifact", "/tmp/scan.yaml", "--target", "openai", "--initializers", "target"] + ) + + assert result == 0 + mock_asyncio_run.assert_called_once() + # asyncio.run was given the artifact coroutine; the normal one was never awaited. + mock_run_scenario.assert_not_called() + mock_run_from_artifact.assert_called_once() + call_kwargs = mock_run_from_artifact.call_args.kwargs + assert call_kwargs["artifact_path"] == Path("/tmp/scan.yaml") + assert call_kwargs["target_name"] == "openai" + assert call_kwargs["allow_drift"] is False + + @patch("pyrit.cli.pyrit_scan.asyncio.run") + @patch("pyrit.cli.frontend_core.run_scenario_from_artifact_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_from_artifact_propagates_allow_drift( + self, + mock_frontend_core: MagicMock, + mock_run_from_artifact: AsyncMock, + mock_asyncio_run: MagicMock, + ): + result = pyrit_scan.main( + [ + "--from-artifact", + "scan.yaml", + "--target", + "openai", + "--allow-drift", + ] + ) + + assert result == 0 + call_kwargs = mock_run_from_artifact.call_args.kwargs + assert call_kwargs["allow_drift"] is True + @patch("pyrit.cli.pyrit_scan.asyncio.run") @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) @patch("pyrit.cli.frontend_core.FrontendCore") diff --git a/tests/unit/scenario/test_waterfall.py b/tests/unit/scenario/test_waterfall.py index 3d03e1aae..f80c8d87b 100644 --- a/tests/unit/scenario/test_waterfall.py +++ b/tests/unit/scenario/test_waterfall.py @@ -28,9 +28,15 @@ ) from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_strategy import ScenarioStrategy -from pyrit.scenario.core.waterfall import policy_to_spec, spec_to_enum +from pyrit.scenario.core.waterfall import ( + enum_to_spec, + policy_to_spec, + spec_to_enum, + spec_to_policy_inputs, +) from pyrit.score import Scorer # ---------- helpers ---------------------------------------------------------- @@ -303,3 +309,155 @@ def test_every_registered_factory_carries_source_spec(self, monkeypatch: pytest. "policy_to_spec can reconstruct the technique catalog " f"(found {len(missing)} without source_spec: {missing})" ) + + +# ---------- enum_to_spec (inverse) ------------------------------------------- + + +@pytest.mark.usefixtures("patch_central_database") +class TestEnumToSpec: + """Inverse half of the waterfall — enum → spec via the global registry.""" + + @pytest.fixture(autouse=True) + def _reset_registry(self, monkeypatch: pytest.MonkeyPatch): + from pyrit.prompt_target import PromptTarget + from pyrit.registry import TargetRegistry + + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + + adversarial = MagicMock(spec=PromptTarget) + adversarial.get_identifier.return_value = ComponentIdentifier( + class_name="MockAdversarial", class_module="tests.unit.scenario" + ) + monkeypatch.setattr( + "pyrit.scenario.core.scenario_techniques.get_default_adversarial_target", + lambda: adversarial, + ) + + yield + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + + def test_empty_input_returns_empty(self): + assert enum_to_spec([]) == [] + + def test_skips_strategies_not_in_global_registry(self): + # _DummyStrategy.ALPHA is not registered in SCENARIO_TECHNIQUES, so the + # global registry has no factory for "alpha" → silently skipped. + result = enum_to_spec([_DummyStrategy.ALPHA]) + assert result == [] + + def test_resolves_registered_strategies(self): + from pyrit.scenario.scenarios.adaptive.text_adaptive import TextAdaptive + + TextAdaptive._cached_strategy_class = None + strategy_cls = TextAdaptive.get_strategy_class() + leaves = strategy_cls.get_all_strategies()[:2] + assert leaves, "TextAdaptive should register at least one leaf strategy" + + specs = enum_to_spec(leaves) + assert [sp.name for sp in specs] == [m.value for m in leaves] + + +# ---------- spec_to_policy_inputs (inverse) ---------------------------------- + + +class _EmptySchemaScenario(_DummyScenario): + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return [] + + +class _OptionalSchemaScenario(_DummyScenario): + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return [ + RoleDescriptor( + name="alpha", + description="optional scalar", + tag=RoleTag.SCALAR, + param_type=int, + default=3, + required=False, + ) + ] + + +class _RequiredScalarWithDefaultScenario(_DummyScenario): + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return [ + RoleDescriptor( + name="alpha", + description="required scalar with default", + tag=RoleTag.SCALAR, + param_type=int, + default=7, + required=True, + ) + ] + + +class _RequiredScalarNoDefaultScenario(_DummyScenario): + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return [ + RoleDescriptor( + name="alpha", + description="required scalar with no default", + tag=RoleTag.SCALAR, + param_type=int, + required=True, + ) + ] + + +class _RequiredOpaqueScenario(_DummyScenario): + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + return [ + RoleDescriptor( + name="step", + description="required opaque step", + tag=RoleTag.OPAQUE, + required=True, + ) + ] + + +class _SchemaRaisesScenario(_DummyScenario): + @classmethod + def input_schema(cls): + raise NotImplementedError("not yet declared") + + +class TestSpecToPolicyInputs: + """Inverse half — derive ``__init__`` kwargs from a spec list when possible.""" + + def test_returns_empty_when_schema_empty(self): + assert spec_to_policy_inputs(_EmptySchemaScenario, []) == {} + + def test_returns_empty_when_default_schema_inherited(self): + # _DummyScenario inherits the base ``input_schema`` returning []. + assert spec_to_policy_inputs(_DummyScenario, []) == {} + + def test_returns_empty_when_only_optional_roles(self): + assert spec_to_policy_inputs(_OptionalSchemaScenario, []) == {} + + def test_returns_empty_when_required_scalar_has_default(self): + assert spec_to_policy_inputs(_RequiredScalarWithDefaultScenario, []) == {} + + def test_returns_none_when_required_scalar_has_no_default(self): + assert spec_to_policy_inputs(_RequiredScalarNoDefaultScenario, []) is None + + def test_returns_none_when_required_opaque_role_present(self): + assert spec_to_policy_inputs(_RequiredOpaqueScenario, []) is None + + def test_treats_not_implemented_schema_as_empty(self): + assert spec_to_policy_inputs(_SchemaRaisesScenario, []) == {} + + def test_specs_argument_is_currently_ignored(self): + # The placeholder signature accepts a spec list; verify the function + # tolerates arbitrary spec inputs without affecting its return value. + assert spec_to_policy_inputs(_EmptySchemaScenario, [_spec("alpha"), _spec("beta")]) == {} From c5f167c8dd5ec1587f69104c9e62270214aed0e1 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 20 May 2026 17:45:17 -0700 Subject: [PATCH 32/42] DOCS: Phase 8f wizard walkthrough + waterfall round-trip integration tests --- doc/code/scenarios/4_scenario_wizard.ipynb | 286 ++++++++++++++++++ doc/code/scenarios/4_scenario_wizard.py | 185 +++++++++++ doc/myst.yml | 1 + .../scenarios/test_notebooks_scenarios.py | 2 +- tests/unit/scenario/test_waterfall.py | 125 ++++++++ 5 files changed, 598 insertions(+), 1 deletion(-) create mode 100644 doc/code/scenarios/4_scenario_wizard.ipynb create mode 100644 doc/code/scenarios/4_scenario_wizard.py diff --git a/doc/code/scenarios/4_scenario_wizard.ipynb b/doc/code/scenarios/4_scenario_wizard.ipynb new file mode 100644 index 000000000..7f33ec03f --- /dev/null +++ b/doc/code/scenarios/4_scenario_wizard.ipynb @@ -0,0 +1,286 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# The Scenario Wizard\n", + "\n", + "The **wizard** is a thin shell around three lower-level surfaces that make scenarios\n", + "composable from inputs rather than hand-written code:\n", + "\n", + "1. **`Scenario.input_schema()`** — declares the rich-object inputs a scenario's\n", + " `__init__` takes (targets, scorers, opaque steps).\n", + "2. **`Scenario.supported_parameters()`** — declares the scalar inputs\n", + " `initialize_async` takes (dataset config, strategies, concurrency).\n", + "3. **`build_scenario_from_inputs`** — drives both phases from a single dict so the\n", + " same call works from a CLI prompt, a Jupyter notebook, or a saved artifact.\n", + "\n", + "This notebook walks through the wizard from the inside: discovering schemas,\n", + "building a scenario programmatically, capturing it as a **graph artifact** for\n", + "later replay, and reloading the artifact. The interactive CLI (`pyrit_wizard`)\n", + "and the replay command (`pyrit_scan --from-artifact`) are documented at the end." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from tempfile import TemporaryDirectory\n", + "\n", + "from pyrit.registry import ScenarioRegistry, TargetRegistry\n", + "from pyrit.scenario.core import (\n", + " build_graph_artifact,\n", + " build_scenario_from_inputs,\n", + " discover_input_schema,\n", + " discover_supported_parameters,\n", + " graph_artifact_from_yaml,\n", + " graph_artifact_to_yaml,\n", + " load_scenario_from_artifact,\n", + ")\n", + "from pyrit.scenario.scenarios.adaptive import TextAdaptive\n", + "from pyrit.setup import initialize_from_config_async\n", + "\n", + "await initialize_from_config_async(config_path=Path(\"../../scanner/pyrit_conf.yaml\")) # type: ignore\n", + "\n", + "objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name(\"openai_chat\")" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## 1. Discovering what a scenario needs\n", + "\n", + "Before building, ask the scenario class what it expects. The two schemas are\n", + "orthogonal: `input_schema()` covers constructor arguments, `supported_parameters()`\n", + "covers `initialize_async`. Every `Scenario` subclass exposes both — most inherit\n", + "the default empty `input_schema()` and only declare scalars in `supported_parameters`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"TextAdaptive constructor inputs (input_schema):\")\n", + "for role in discover_input_schema(TextAdaptive):\n", + " required = \"required\" if role.required else \"optional\"\n", + " default = f\"default={role.default!r}\" if role.default is not None else \"\"\n", + " print(f\" - {role.name:32s} {role.tag.value:13s} {required:9s} {default}\")\n", + "\n", + "print(\"\\nTextAdaptive initialize_async inputs (supported_parameters):\")\n", + "for param in discover_supported_parameters(TextAdaptive):\n", + " print(f\" - {param.name:32s} default={param.default!r}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "All four `TextAdaptive` input-schema roles are optional with defaults, so the\n", + "wizard can build the scenario from a completely empty input dict. Scenarios that\n", + "declare required `OPAQUE` roles (like `BroadSweepThenDeepDive`) cannot be built\n", + "this way — the CLI wizard rejects them with a hint to use `--from-artifact`,\n", + "and programmatic callers must supply the rich objects directly." + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## 2. Building a scenario from inputs\n", + "\n", + "`build_scenario_from_inputs` constructs the scenario, then runs `initialize_async`.\n", + "It returns a fully initialized scenario ready for `run_async`. The same call shape\n", + "works from any front-end — what changes is who provides the input dicts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "strategy_class = TextAdaptive.get_strategy_class()\n", + "\n", + "scenario = await build_scenario_from_inputs( # type: ignore\n", + " TextAdaptive,\n", + " init_inputs={\n", + " \"epsilon\": 0.3,\n", + " \"max_attempts_per_objective\": 4,\n", + " \"seed\": 42,\n", + " },\n", + " init_async_inputs={\n", + " \"objective_target\": objective_target,\n", + " \"scenario_strategies\": [strategy_class(\"single_turn\")],\n", + " },\n", + ")\n", + "\n", + "print(f\"Built scenario: {scenario.name}\")\n", + "print(f\" scenario_strategies: {[s.value for s in scenario._scenario_strategies]}\")\n", + "print(f\" epsilon: {scenario._epsilon}, max_attempts: {scenario._max_attempts_per_objective}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## 3. Capturing the scenario as a graph artifact\n", + "\n", + "A `GraphArtifact` is a frozen snapshot of everything needed to rebuild the same\n", + "scenario later: the class FQN, the constructor inputs (opaque rich objects encoded\n", + "by `ComponentIdentifier`), the `initialize_async` inputs, and a topology hash for\n", + "drift detection. The objective target is **not** captured — it's environment-specific." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "artifact = build_graph_artifact(scenario)\n", + "print(f\" scenario_class_fqn: {artifact.scenario_class_fqn}\")\n", + "print(f\" scenario_version: {artifact.scenario_version}\")\n", + "print(f\" pyrit_version: {artifact.pyrit_version}\")\n", + "print(f\" topology_hash: {artifact.topology_hash[:12]}…\")\n", + "print(f\" init_async_inputs: {sorted(artifact.init_async_inputs.keys())}\")" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "Serialize to YAML for sharing or version control. The dump is canonical (sorted\n", + "keys, block-style) so equivalent scenarios produce byte-identical artifacts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "with TemporaryDirectory() as tmpdir:\n", + " artifact_path = Path(tmpdir) / \"text_adaptive.yaml\"\n", + " graph_artifact_to_yaml(artifact, artifact_path)\n", + " print(artifact_path.read_text()[:800])" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "## 4. Reloading the artifact\n", + "\n", + "`load_scenario_from_artifact` re-runs the registered scenario class through\n", + "`build_scenario_from_inputs` with the captured inputs, then verifies the\n", + "rebuilt graph's topology hash matches the captured one. Drift fails by default;\n", + "pass `allow_drift=True` to tolerate version or topology changes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "with TemporaryDirectory() as tmpdir:\n", + " artifact_path = Path(tmpdir) / \"text_adaptive.yaml\"\n", + " graph_artifact_to_yaml(artifact, artifact_path)\n", + "\n", + " reloaded_artifact = graph_artifact_from_yaml(artifact_path)\n", + " reloaded_scenario = await load_scenario_from_artifact( # type: ignore\n", + " reloaded_artifact,\n", + " objective_target=objective_target,\n", + " )\n", + "\n", + "print(f\"Reloaded scenario: {reloaded_scenario.name}\")\n", + "print(f\" topology_hash matches: {build_graph_artifact(reloaded_scenario).topology_hash == artifact.topology_hash}\")" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "## 5. The interactive CLI surface\n", + "\n", + "Two CLIs sit on top of the building blocks above:\n", + "\n", + "- **`pyrit_wizard`** — prompts for each role declared by the chosen scenario's\n", + " `input_schema()` and `supported_parameters()`, then either runs the scenario\n", + " (`--run`) or persists it as a graph artifact (`--save path.yaml`). Useful for\n", + " first-time exploration or quick one-offs.\n", + "\n", + "- **`pyrit_scan --from-artifact path.yaml --target `** — loads a previously\n", + " saved artifact and replays it against the named target. Use this in CI or to\n", + " share a reproducible attack with a collaborator. `--allow-drift` tolerates\n", + " scenario_version or topology-hash mismatches.\n", + "\n", + "The wizard cannot elicit `BroadSweepThenDeepDive` (or any scenario with\n", + "required `OPAQUE` roles) directly — its constructor takes pre-built\n", + "`AtomicAttack` instances and closures the CLI cannot construct. For those\n", + "scenarios, build once programmatically (as in section 2 above), save the\n", + "artifact, and replay via `pyrit_scan --from-artifact`." + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Discovery: what scenarios are available?\n", + "\n", + "Scenarios are self-registering via `ScenarioRegistry`. The wizard's `--list`\n", + "flag delegates to the same metadata." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "scenario_registry = ScenarioRegistry.get_registry_singleton()\n", + "for metadata in sorted(scenario_registry.list_metadata(), key=lambda m: m.registry_name)[:8]:\n", + " summary = metadata.class_description.splitlines()[0] if metadata.class_description else \"\"\n", + " print(f\" {metadata.registry_name:36s} {summary[:60]}\")" + ] + } + ], + "metadata": { + "jupytext": { + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/scenarios/4_scenario_wizard.py b/doc/code/scenarios/4_scenario_wizard.py new file mode 100644 index 000000000..8645dd78b --- /dev/null +++ b/doc/code/scenarios/4_scenario_wizard.py @@ -0,0 +1,185 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# --- + +# %% [markdown] +# # The Scenario Wizard +# +# The **wizard** is a thin shell around three lower-level surfaces that make scenarios +# composable from inputs rather than hand-written code: +# +# 1. **`Scenario.input_schema()`** — declares the rich-object inputs a scenario's +# `__init__` takes (targets, scorers, opaque steps). +# 2. **`Scenario.supported_parameters()`** — declares the scalar inputs +# `initialize_async` takes (dataset config, strategies, concurrency). +# 3. **`build_scenario_from_inputs`** — drives both phases from a single dict so the +# same call works from a CLI prompt, a Jupyter notebook, or a saved artifact. +# +# This notebook walks through the wizard from the inside: discovering schemas, +# building a scenario programmatically, capturing it as a **graph artifact** for +# later replay, and reloading the artifact. The interactive CLI (`pyrit_wizard`) +# and the replay command (`pyrit_scan --from-artifact`) are documented at the end. + +# %% [markdown] +# ## Setup + +# %% +from pathlib import Path +from tempfile import TemporaryDirectory + +from pyrit.registry import ScenarioRegistry, TargetRegistry +from pyrit.scenario.core import ( + build_graph_artifact, + build_scenario_from_inputs, + discover_input_schema, + discover_supported_parameters, + graph_artifact_from_yaml, + graph_artifact_to_yaml, + load_scenario_from_artifact, +) +from pyrit.scenario.scenarios.adaptive import TextAdaptive +from pyrit.setup import initialize_from_config_async + +await initialize_from_config_async(config_path=Path("../../scanner/pyrit_conf.yaml")) # type: ignore + +objective_target = TargetRegistry.get_registry_singleton().get_instance_by_name("openai_chat") + +# %% [markdown] +# ## 1. Discovering what a scenario needs +# +# Before building, ask the scenario class what it expects. The two schemas are +# orthogonal: `input_schema()` covers constructor arguments, `supported_parameters()` +# covers `initialize_async`. Every `Scenario` subclass exposes both — most inherit +# the default empty `input_schema()` and only declare scalars in `supported_parameters`. + +# %% +print("TextAdaptive constructor inputs (input_schema):") +for role in discover_input_schema(TextAdaptive): + required = "required" if role.required else "optional" + default = f"default={role.default!r}" if role.default is not None else "" + print(f" - {role.name:32s} {role.tag.value:13s} {required:9s} {default}") + +print("\nTextAdaptive initialize_async inputs (supported_parameters):") +for param in discover_supported_parameters(TextAdaptive): + print(f" - {param.name:32s} default={param.default!r}") + +# %% [markdown] +# All four `TextAdaptive` input-schema roles are optional with defaults, so the +# wizard can build the scenario from a completely empty input dict. Scenarios that +# declare required `OPAQUE` roles (like `BroadSweepThenDeepDive`) cannot be built +# this way — the CLI wizard rejects them with a hint to use `--from-artifact`, +# and programmatic callers must supply the rich objects directly. + +# %% [markdown] +# ## 2. Building a scenario from inputs +# +# `build_scenario_from_inputs` constructs the scenario, then runs `initialize_async`. +# It returns a fully initialized scenario ready for `run_async`. The same call shape +# works from any front-end — what changes is who provides the input dicts. + +# %% +strategy_class = TextAdaptive.get_strategy_class() + +scenario = await build_scenario_from_inputs( # type: ignore + TextAdaptive, + init_inputs={ + "epsilon": 0.3, + "max_attempts_per_objective": 4, + "seed": 42, + }, + init_async_inputs={ + "objective_target": objective_target, + "scenario_strategies": [strategy_class("single_turn")], + }, +) + +print(f"Built scenario: {scenario.name}") +print(f" scenario_strategies: {[s.value for s in scenario._scenario_strategies]}") +print(f" epsilon: {scenario._epsilon}, max_attempts: {scenario._max_attempts_per_objective}") + +# %% [markdown] +# ## 3. Capturing the scenario as a graph artifact +# +# A `GraphArtifact` is a frozen snapshot of everything needed to rebuild the same +# scenario later: the class FQN, the constructor inputs (opaque rich objects encoded +# by `ComponentIdentifier`), the `initialize_async` inputs, and a topology hash for +# drift detection. The objective target is **not** captured — it's environment-specific. + +# %% +artifact = build_graph_artifact(scenario) +print(f" scenario_class_fqn: {artifact.scenario_class_fqn}") +print(f" scenario_version: {artifact.scenario_version}") +print(f" pyrit_version: {artifact.pyrit_version}") +print(f" topology_hash: {artifact.topology_hash[:12]}…") +print(f" init_async_inputs: {sorted(artifact.init_async_inputs.keys())}") + +# %% [markdown] +# Serialize to YAML for sharing or version control. The dump is canonical (sorted +# keys, block-style) so equivalent scenarios produce byte-identical artifacts. + +# %% +with TemporaryDirectory() as tmpdir: + artifact_path = Path(tmpdir) / "text_adaptive.yaml" + graph_artifact_to_yaml(artifact, artifact_path) + print(artifact_path.read_text()[:800]) + +# %% [markdown] +# ## 4. Reloading the artifact +# +# `load_scenario_from_artifact` re-runs the registered scenario class through +# `build_scenario_from_inputs` with the captured inputs, then verifies the +# rebuilt graph's topology hash matches the captured one. Drift fails by default; +# pass `allow_drift=True` to tolerate version or topology changes. + +# %% +with TemporaryDirectory() as tmpdir: + artifact_path = Path(tmpdir) / "text_adaptive.yaml" + graph_artifact_to_yaml(artifact, artifact_path) + + reloaded_artifact = graph_artifact_from_yaml(artifact_path) + reloaded_scenario = await load_scenario_from_artifact( # type: ignore + reloaded_artifact, + objective_target=objective_target, + ) + +print(f"Reloaded scenario: {reloaded_scenario.name}") +print(f" topology_hash matches: {build_graph_artifact(reloaded_scenario).topology_hash == artifact.topology_hash}") + +# %% [markdown] +# ## 5. The interactive CLI surface +# +# Two CLIs sit on top of the building blocks above: +# +# - **`pyrit_wizard`** — prompts for each role declared by the chosen scenario's +# `input_schema()` and `supported_parameters()`, then either runs the scenario +# (`--run`) or persists it as a graph artifact (`--save path.yaml`). Useful for +# first-time exploration or quick one-offs. +# +# - **`pyrit_scan --from-artifact path.yaml --target `** — loads a previously +# saved artifact and replays it against the named target. Use this in CI or to +# share a reproducible attack with a collaborator. `--allow-drift` tolerates +# scenario_version or topology-hash mismatches. +# +# The wizard cannot elicit `BroadSweepThenDeepDive` (or any scenario with +# required `OPAQUE` roles) directly — its constructor takes pre-built +# `AtomicAttack` instances and closures the CLI cannot construct. For those +# scenarios, build once programmatically (as in section 2 above), save the +# artifact, and replay via `pyrit_scan --from-artifact`. + +# %% [markdown] +# ## Discovery: what scenarios are available? +# +# Scenarios are self-registering via `ScenarioRegistry`. The wizard's `--list` +# flag delegates to the same metadata. + +# %% +scenario_registry = ScenarioRegistry.get_registry_singleton() +for metadata in sorted(scenario_registry.list_metadata(), key=lambda m: m.registry_name)[:8]: + summary = metadata.class_description.splitlines()[0] if metadata.class_description else "" + print(f" {metadata.registry_name:36s} {summary[:60]}") diff --git a/doc/myst.yml b/doc/myst.yml index 10799bf49..fcd0a05a9 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -170,6 +170,7 @@ project: - file: code/scenarios/1_common_scenario_parameters.ipynb - file: code/scenarios/2_custom_scenario_parameters.ipynb - file: code/scenarios/3_adaptive_scenarios.ipynb + - file: code/scenarios/4_scenario_wizard.ipynb - file: code/registry/0_registry.md children: - file: code/registry/1_class_registry.ipynb diff --git a/tests/integration/scenarios/test_notebooks_scenarios.py b/tests/integration/scenarios/test_notebooks_scenarios.py index b00783c6d..71baca3fa 100644 --- a/tests/integration/scenarios/test_notebooks_scenarios.py +++ b/tests/integration/scenarios/test_notebooks_scenarios.py @@ -12,7 +12,7 @@ nb_directory_path = pathlib.Path(path.DOCS_CODE_PATH, "scenarios").resolve() -skipped_files: list[str] = [] +skipped_files: list[str] = ["4_scenario_wizard.ipynb"] @pytest.mark.parametrize( diff --git a/tests/unit/scenario/test_waterfall.py b/tests/unit/scenario/test_waterfall.py index f80c8d87b..4138411e0 100644 --- a/tests/unit/scenario/test_waterfall.py +++ b/tests/unit/scenario/test_waterfall.py @@ -461,3 +461,128 @@ def test_specs_argument_is_currently_ignored(self): # The placeholder signature accepts a spec list; verify the function # tolerates arbitrary spec inputs without affecting its return value. assert spec_to_policy_inputs(_EmptySchemaScenario, [_spec("alpha"), _spec("beta")]) == {} + + +# ---------- Phase 8f: 4-function round-trip across first-party scenarios ----- + + +@pytest.mark.usefixtures("patch_central_database") +class TestFourFunctionRoundTrip: + """ + Integration coverage that walks every waterfall function in sequence for + the three first-party scenario shapes: + + * **Legacy linear** (``Encoding``) — inherits the default ``input_schema`` + of ``[]``; the scenario is reconstructible from CLI flags alone. + * **Adaptive** (``TextAdaptive``) — declares 4 optional scalar inputs, + all with defaults; ``spec_to_policy_inputs`` returns ``{}``. + * **Policy-parameterized** (``BroadSweepThenDeepDive``) — declares 3 + required OPAQUE roles; ``spec_to_policy_inputs`` is forced to return + ``None`` because closures and bound ``Identifiable`` instances cannot + be reconstructed from spec metadata alone. + + For each scenario we run ``policy_to_spec → spec_to_enum → enum_to_spec`` + and assert the spec catalog is preserved by name through the round trip, + then assert ``spec_to_policy_inputs`` returns the documented value. + """ + + @pytest.fixture(autouse=True) + def _reset(self, monkeypatch: pytest.MonkeyPatch): + from pyrit.prompt_target import PromptTarget + from pyrit.registry import TargetRegistry + + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + + adversarial = MagicMock(spec=PromptTarget) + adversarial.get_identifier.return_value = ComponentIdentifier( + class_name="MockAdversarial", class_module="tests.unit.scenario" + ) + monkeypatch.setattr( + "pyrit.scenario.core.scenario_techniques.get_default_adversarial_target", + lambda: adversarial, + ) + monkeypatch.setattr( + "pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer", + lambda self: _make_scorer_mock(), + ) + + yield + AttackTechniqueRegistry.reset_instance() + TargetRegistry.reset_instance() + + def _round_trip(self, scenario: Scenario, scenario_cls: type[Scenario]): + """policy_to_spec → spec_to_enum → enum_to_spec; assert catalog preserved.""" + specs = policy_to_spec(scenario) + original_names = [sp.name for sp in specs] + + enums = spec_to_enum(scenario_cls, specs) + assert enums is not None, f"{scenario_cls.__name__}: spec_to_enum returned None" + assert [m.value for m in enums] == original_names + + rebuilt = enum_to_spec(enums) + # `enum_to_spec` is best-effort over the global registry; for scenarios + # whose catalog is fully registered (legacy linear, adaptive) this is a + # full round-trip. + assert [sp.name for sp in rebuilt] == original_names + return specs + + def test_encoding_round_trip(self): + """Legacy linear scenario (non-registry catalog): spec catalog is empty by design.""" + from pyrit.scenario.scenarios.garak.encoding import Encoding + + Encoding._cached_strategy_class = None + strategy_cls = Encoding.get_strategy_class() + default = Encoding.get_default_strategy() + resolved = strategy_cls.resolve(None, default=default) + + scenario = Encoding() + scenario._scenario_strategies = resolved + + # Encoding doesn't participate in the AttackTechniqueRegistry catalog + # — its strategies map to per-encoding atomic attacks rather than + # registered techniques. The forward waterfall degrades to empty. + specs = policy_to_spec(scenario) + assert specs == [], "Encoding should expose no registry-catalog specs" + + # spec_to_enum on an empty list is the empty list (not None). + enums = spec_to_enum(Encoding, specs) + assert enums == [] + + rebuilt = enum_to_spec(enums) + assert rebuilt == [] + + # Default schema → no required constructor inputs. + assert spec_to_policy_inputs(Encoding, specs) == {} + + def test_text_adaptive_round_trip(self): + """Adaptive scenario: 4 optional scalars, full registry catalog.""" + from pyrit.scenario.scenarios.adaptive.text_adaptive import TextAdaptive + + TextAdaptive._cached_strategy_class = None + strategy_cls = TextAdaptive.get_strategy_class() + default = TextAdaptive.get_default_strategy() + resolved = strategy_cls.resolve(None, default=default) + + scenario = TextAdaptive() + scenario._scenario_strategies = resolved + + specs = self._round_trip(scenario, TextAdaptive) + assert specs, "TextAdaptive should expose at least one technique spec" + + # All 4 input_schema roles are optional with defaults → returns {}. + assert spec_to_policy_inputs(TextAdaptive, specs) == {} + + def test_broad_sweep_then_deep_dive_returns_none_from_policy_inputs(self): + """Policy-parameterized scenario: opaque-required schema forces ``None``.""" + from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import BroadSweepThenDeepDive + + # The scenario's input_schema declares required OPAQUE roles. We don't + # need to instantiate it to assert this — spec_to_policy_inputs is a + # classmethod-style query over the type. + result = spec_to_policy_inputs(BroadSweepThenDeepDive, []) + assert result is None, ( + "BroadSweepThenDeepDive declares required OPAQUE roles; " + "spec_to_policy_inputs must return None to signal the wizard / CLI " + "should fall back to --from-artifact." + ) From e61df6f39c687766a16f1422f0f8766a1c63c65b Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 11:25:25 -0700 Subject: [PATCH 33/42] FEAT: ScenarioPipeline composition primitive (R5) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/__init__.py | 17 + pyrit/scenario/composite/__init__.py | 34 + pyrit/scenario/composite/scenario_pipeline.py | 744 ++++++++++++++++++ tests/unit/scenario/composite/__init__.py | 2 + .../composite/test_scenario_pipeline.py | 486 ++++++++++++ 5 files changed, 1283 insertions(+) create mode 100644 pyrit/scenario/composite/__init__.py create mode 100644 pyrit/scenario/composite/scenario_pipeline.py create mode 100644 tests/unit/scenario/composite/__init__.py create mode 100644 tests/unit/scenario/composite/test_scenario_pipeline.py diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 657787571..9a3ea03e8 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -17,6 +17,15 @@ from pyrit.common.parameter import Parameter from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult +from pyrit.scenario import composite as _composite_module +from pyrit.scenario.composite import ( + ConditionalPhaseSpec, + PhaseExecution, + PhaseSpec, + PipelineContext, + ScenarioPipeline, + ScenarioPipelineStrategy, +) from pyrit.scenario.core import ( ArtifactInputCollector, AtomicAttack, @@ -87,6 +96,7 @@ benchmark = _benchmark_module garak = _garak_module foundry = _foundry_module +composite = _composite_module __all__ = [ "ArtifactInputCollector", @@ -95,6 +105,7 @@ "AttackTechniqueFactory", "BaselineAttackPolicy", "CliInputCollector", + "ConditionalPhaseSpec", "DatasetConfiguration", "DictInputCollector", "GraphArtifact", @@ -106,6 +117,9 @@ "OpaqueInputUnresolvedError", "OpaqueRoleNotElicitableError", "Parameter", + "PhaseExecution", + "PhaseSpec", + "PipelineContext", "PolicyAction", "RoleDescriptor", "RoleTag", @@ -114,6 +128,8 @@ "ScenarioCoreState", "ScenarioIdentifier", "ScenarioInputValidationError", + "ScenarioPipeline", + "ScenarioPipelineStrategy", "ScenarioResult", "ScenarioStateLike", "ScenarioStep", @@ -128,6 +144,7 @@ "build_scenario_from_inputs", "build_topology_summary", "collect_inputs_with_retry", + "composite", "discover_input_schema", "discover_supported_parameters", "enum_to_spec", diff --git a/pyrit/scenario/composite/__init__.py b/pyrit/scenario/composite/__init__.py new file mode 100644 index 000000000..9c9ae6164 --- /dev/null +++ b/pyrit/scenario/composite/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Composite scenarios — Scenario instances that compose other scenarios. + +The first composite ships in this subpackage is :class:`ScenarioPipeline`, a +linear sequential composer that runs a list of inner scenarios as phases. It +is the direct response to the review suggestion that the simple cross-scenario +composition case should not require authoring a custom branching policy. + +Composite scenarios use :class:`StrategyGraph` internally for orchestration so +they remain consistent with the rest of the scenario core, but their public +API surface (:class:`PhaseSpec`, :class:`ConditionalPhaseSpec`) keeps the +state-machine vocabulary hidden from the caller. +""" + +from pyrit.scenario.composite.scenario_pipeline import ( + ConditionalPhaseSpec, + PhaseExecution, + PhaseSpec, + PipelineContext, + ScenarioPipeline, + ScenarioPipelineStrategy, +) + +__all__ = [ + "ConditionalPhaseSpec", + "PhaseExecution", + "PhaseSpec", + "PipelineContext", + "ScenarioPipeline", + "ScenarioPipelineStrategy", +] diff --git a/pyrit/scenario/composite/scenario_pipeline.py b/pyrit/scenario/composite/scenario_pipeline.py new file mode 100644 index 000000000..7ae6363f2 --- /dev/null +++ b/pyrit/scenario/composite/scenario_pipeline.py @@ -0,0 +1,744 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``ScenarioPipeline`` — sequential composition of inner scenarios as phases. + +Direct response to the review note that the simple cross-scenario composition +case ("run scenario A, then scenario B, then scenario C") should not require +authoring a custom branching policy. ``ScenarioPipeline`` is a thin +``Scenario`` subclass that walks a list of :class:`PhaseSpec` instances in +order, dispatching each phase's inner scenario through its own +``initialize_async`` / ``run_async`` lifecycle. The pipeline itself owns a +linear :class:`StrategyGraph` over integer states so it stays consistent with +the rest of the scenario core, but callers never need to touch the graph +vocabulary — they hand the pipeline a list of phase specs. + +Phase-level conditional skipping is expressed via :class:`ConditionalPhaseSpec`, +whose ``skip_when`` predicate receives a read-only :class:`PipelineContext` +snapshot of prior phase executions. + +Pipeline-level resume is intentionally out of scope for the initial release: +each pipeline run constructs fresh inner scenarios with ``scenario_result_id`` +unset. Pipelines wanting to resume can supply phase factories that recreate +inner scenarios with previously-persisted scenario result ids. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast + +from pyrit.common import apply_defaults +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import Score +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario +from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult +from pyrit.scenario.core.scenario_strategy import ScenarioStrategy +from pyrit.scenario.core.strategy_graph import ( + PolicyAction, + StrategyGraph, + StrategyPolicy, +) +from pyrit.score.scorer_prompt_validator import ScorerPromptValidator +from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + +if TYPE_CHECKING: + import uuid + from collections.abc import Callable, Mapping, Sequence + + from pyrit.models import MessagePiece + from pyrit.models.scenario_result import ScenarioResult + from pyrit.models.score import ScoreType + from pyrit.scenario.core.atomic_attack import AtomicAttack + from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory + from pyrit.score.scorer import Scorer + +logger = logging.getLogger(__name__) + + +_PHASE_OUTCOME_COMPLETED = "completed" +_PHASE_OUTCOME_SKIPPED = "skipped" + + +@dataclass(frozen=True) +class PhaseExecution: + """ + Snapshot of one phase's execution result. + + Surfaced to :class:`PipelineContext` for :class:`ConditionalPhaseSpec` + predicates and accessible after the pipeline completes via + :attr:`ScenarioPipeline.phase_executions`. + + Attributes: + name (str): The phase's name as declared in :class:`PhaseSpec`. + outcome (str): Either ``"completed"`` (phase ran to completion) or + ``"skipped"`` (a :class:`ConditionalPhaseSpec` predicate elided + execution). Failures propagate as exceptions out of + ``ScenarioPipeline.run_async`` rather than landing here. + scenario_result (ScenarioResult | None): The phase's inner + :class:`ScenarioResult` if the phase ran; ``None`` if the phase + was skipped. + """ + + name: str + outcome: str + scenario_result: Optional[ScenarioResult] = None + + +@dataclass(frozen=True) +class PipelineContext: + """ + Read-only view of prior phase executions, supplied to predicates. + + :class:`ConditionalPhaseSpec.skip_when` receives one of these so predicates + can branch on the outcomes (or full ``ScenarioResult`` payloads) of phases + that already ran. Tuples (not lists) make accidental mutation by the + predicate impossible. + + Attributes: + completed_phase_names (tuple[str, ...]): Names of every phase that has + either completed or been skipped so far, in dispatch order. + completed_phase_outcomes (tuple[str, ...]): The matching outcomes + (``"completed"`` or ``"skipped"``) in the same order. + phase_executions (tuple[PhaseExecution, ...]): Full execution snapshots + including the inner ``ScenarioResult`` for non-skipped phases. + """ + + completed_phase_names: tuple[str, ...] = () + completed_phase_outcomes: tuple[str, ...] = () + phase_executions: tuple[PhaseExecution, ...] = () + + +@dataclass(frozen=True) +class PhaseSpec: + """ + Declarative description of one pipeline phase. + + A phase wraps a single inner ``Scenario`` that is constructed fresh for + each pipeline run via ``scenario_factory``. The pipeline owns + ``initialize_async`` and ``run_async`` for the inner scenario, merging + ``init_async_kwargs`` with auto-injected values (``objective_target`` and + ``memory_labels`` flow from the outer pipeline unless explicitly + overridden in ``init_async_kwargs``). + + Attributes: + name (str): A unique, human-readable label for this phase. Used as + the bucket key in :class:`ScenarioResult.attack_results` so + downstream consumers can pull per-phase results from the + pipeline-level result. + scenario_factory (Callable[[], Scenario]): Zero-arg callable that + builds a fresh, uninitialized inner scenario instance. Called + once per pipeline run when the phase's turn arrives. Use + ``lambda: MyScenario(target=..., scorer=...)`` for the common + case where construction args don't depend on prior phases. + init_async_kwargs (Mapping[str, Any]): Kwargs forwarded verbatim + to the inner scenario's ``initialize_async``. ``objective_target`` + and ``memory_labels`` are auto-injected from the outer pipeline + unless explicitly provided here. + """ + + name: str + scenario_factory: Callable[[], Scenario] + init_async_kwargs: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) + + +@dataclass(frozen=True) +class ConditionalPhaseSpec(PhaseSpec): + """ + Phase that runs only when ``skip_when`` returns False for the current context. + + The predicate is evaluated once at dispatch time, after every prior phase + has either completed or been skipped. Predicates that need access to a + prior phase's ``ScenarioResult`` (e.g. to inspect successful attacks) + should look it up by name from :attr:`PipelineContext.phase_executions`. + + Attributes: + skip_when (Callable[[PipelineContext], bool]): Returns True to skip + this phase, False to run it. Defaults to never skipping (which is + equivalent to using a plain :class:`PhaseSpec`). + """ + + skip_when: Callable[[PipelineContext], bool] = field(default=lambda _ctx: False) + + +class ScenarioPipelineStrategy(ScenarioStrategy): + """ + Single-member strategy enum for :class:`ScenarioPipeline`. + + Pipelines don't expose a technique selection menu — composition is + expressed via the ``phases`` constructor argument. The strategy enum + exists only to satisfy the base :class:`Scenario` contract. + """ + + DEFAULT = ("default", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + """Return the strategy aggregate tags this enum exposes.""" + return {"all"} + + +class _PipelineNoopScorer(TrueFalseScorer): + """ + Stand-in scorer for :class:`ScenarioPipeline` when no pipeline-level scorer is provided. + + The pipeline doesn't directly score anything — each inner scenario brings + its own scorer. But ``Scenario.__init__`` requires a ``Scorer`` for its + identifier graph, so we hand it a deterministic no-op that always emits a + false score. Never actually invoked during pipeline execution because the + pipeline does not call its base ``objective_scorer`` on the phase results. + """ + + def __init__(self) -> None: + super().__init__(validator=ScorerPromptValidator()) + + async def _score_piece_async( + self, + message_piece: MessagePiece, + *, + objective: Optional[str] = None, + ) -> list[Score]: + # ``MessagePiece.id`` is typed Optional but is populated for every persisted + # piece (its default factory mints a UUID). The cast satisfies ``Score.__init__`` + # without paying a runtime ``assert`` that would never fire in practice. + return [ + Score( + score_value="False", + score_value_description="ScenarioPipeline no-op scorer; not used during pipeline execution.", + score_type=cast("ScoreType", "true_false"), + score_rationale="Pipeline does not score attack results directly; inner scenarios own scoring.", + message_piece_id=cast("str | uuid.UUID", message_piece.id), + scorer_class_identifier=self.get_identifier(), + objective=objective, + ) + ] + + def _build_identifier(self) -> ComponentIdentifier: + return ComponentIdentifier.of(self) + + +class _ScenarioPipelinePhaseStep(ScenarioStep): + """ + One pipeline phase, dispatched by :class:`ScenarioPipeline`'s execution graph. + + Holds a :class:`PhaseSpec` and a back-reference to the parent pipeline + (so the predicate can read prior-phase context and the inner scenario can + inherit the pipeline's ``objective_target`` and ``memory_labels``). + Duck-types the :class:`AtomicAttack` attributes the base ``Scenario`` + orchestrator inspects (``atomic_attack_name``, ``display_group``, + ``seed_groups``, ``objectives``, ``technique_eval_hash``, + ``filter_seed_groups_by_objectives``, ``drop_seed_groups_with_hashes``) + so the pipeline can plug into the existing orchestrator without + special-casing branching scenarios. + + The step's ``process_async`` evaluates the conditional predicate (if any), + builds the inner scenario via the factory, initializes it with merged + kwargs, and runs it. The returned :class:`ScenarioStepResult` carries an + empty ``attack_results`` list — inner scenarios persist their own results + against their own ``scenario_result_id``, so re-emitting them here would + duplicate-persist them under the pipeline's own ``scenario_result_id``. + """ + + _OUTPUTS: ClassVar[tuple[str, ...]] = (_PHASE_OUTCOME_COMPLETED, _PHASE_OUTCOME_SKIPPED) + + #: Marker for ``graph_artifact.build_graph_artifact`` (Phase 8g): each phase + #: holds a ``Callable`` scenario factory whose closure cannot be reconstructed + #: from primitive args. Encoding via ``ComponentIdentifier.to_dict()`` is the + #: only sound round-trip path. + GRAPH_ARTIFACT_OPAQUE: ClassVar[bool] = True + + def __init__( + self, + *, + spec: PhaseSpec, + index: int, + pipeline: ScenarioPipeline, + ) -> None: + """ + Initialize the phase step. + + Args: + spec (PhaseSpec): The phase declaration. + index (int): Zero-based position in the pipeline's phase list. + Surfaced into step metadata for diagnostics and used as the + policy state key in the pipeline's linear graph. + pipeline (ScenarioPipeline): Back-reference used to (a) inherit + ``objective_target`` and ``memory_labels`` into the inner + scenario, and (b) feed :class:`PipelineContext` to + :class:`ConditionalPhaseSpec` predicates. + """ + self.name = spec.name + self.outputs = list(self._OUTPUTS) + self._spec = spec + self._index = index + self._pipeline = pipeline + + # Duck-typed AtomicAttack-like attributes so the orchestrator's + # display-group map and resume bookkeeping continue to work without + # special-casing pipeline phases. + self.atomic_attack_name = spec.name + self.display_group = spec.name + self.seed_groups: list[Any] = [] + self.objectives: list[str] = [] + + # ``Scenario._get_completed_objective_hashes_for_attack`` reads + # ``technique_eval_hash`` to scope its rows-by-attribution lookup. + # Pipeline phases have no per-row attribution at the pipeline level + # (inner scenarios persist under their own ids), so the value is a + # stable sentinel that can never collide with a real eval hash. + self.technique_eval_hash: str = f"pipeline-phase::{spec.name}" + + def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) -> None: + """No-op: pipeline phases own no seed groups at the pipeline level.""" + + def drop_seed_groups_with_hashes(self, *, hashes: set[str]) -> None: + """No-op: pipeline phases own no seed groups at the pipeline level.""" + + async def process_async(self) -> ScenarioStepResult: + """ + Run the inner scenario (or skip via ``ConditionalPhaseSpec`` predicate). + + Returns: + ScenarioStepResult: Outcome is ``"completed"`` if the inner + scenario ran, ``"skipped"`` if the predicate elided the + phase. ``attack_results`` is always empty — inner scenarios + own their own persistence. ``metadata`` carries + ``phase_name`` / ``phase_index`` / (when run) + ``inner_scenario_result_id`` for diagnostics and downstream + lookup. + + Raises: + TypeError: If ``scenario_factory`` returns an object that is not + a :class:`Scenario` instance. + """ + spec = self._spec + pipeline = self._pipeline + + if isinstance(spec, ConditionalPhaseSpec): + context = pipeline._snapshot_pipeline_context() + if spec.skip_when(context): + logger.info( + "ScenarioPipeline '%s' phase %d/%d '%s': SKIPPED by predicate", + pipeline._name, + self._index + 1, + len(pipeline._phases), + spec.name, + ) + execution = PhaseExecution(name=spec.name, outcome=_PHASE_OUTCOME_SKIPPED, scenario_result=None) + pipeline._record_phase_execution(execution=execution) + return ScenarioStepResult( + outcome=_PHASE_OUTCOME_SKIPPED, + attack_results=[], + metadata={ + "phase_name": spec.name, + "phase_index": self._index, + "skipped": True, + }, + ) + + logger.info( + "ScenarioPipeline '%s' phase %d/%d '%s': running", + pipeline._name, + self._index + 1, + len(pipeline._phases), + spec.name, + ) + + inner_scenario = spec.scenario_factory() + if not isinstance(inner_scenario, Scenario): + raise TypeError( + f"PhaseSpec(name={spec.name!r}).scenario_factory() returned " + f"{type(inner_scenario).__name__!r}, expected a Scenario instance." + ) + + init_kwargs = dict(spec.init_async_kwargs) + if "objective_target" not in init_kwargs and pipeline._objective_target is not None: + init_kwargs["objective_target"] = pipeline._objective_target + if "memory_labels" not in init_kwargs and pipeline._memory_labels: + init_kwargs["memory_labels"] = dict(pipeline._memory_labels) + + await inner_scenario.initialize_async(**init_kwargs) + inner_result = await inner_scenario.run_async() + + execution = PhaseExecution(name=spec.name, outcome=_PHASE_OUTCOME_COMPLETED, scenario_result=inner_result) + pipeline._record_phase_execution(execution=execution) + + return ScenarioStepResult( + outcome=_PHASE_OUTCOME_COMPLETED, + attack_results=[], + metadata={ + "phase_name": spec.name, + "phase_index": self._index, + "inner_scenario_result_id": str(inner_result.id) if inner_result.id is not None else None, + }, + ) + + def _build_identifier(self) -> ComponentIdentifier: + """ + Build the behavioral identity for this phase step. + + The factory callable is opaque — only the phase name and index are + captured. ``GRAPH_ARTIFACT_OPAQUE = True`` tells artifact serializers + to round-trip via the identifier hash rather than introspecting the + constructor closure. + + Returns: + ComponentIdentifier: The frozen identity snapshot. + """ + return ComponentIdentifier.of( + self, + params={ + "phase_name": self.name, + "phase_index": self._index, + "outputs": list(self.outputs), + }, + ) + + +class ScenarioPipeline(Scenario): + """ + Compose multiple scenarios as sequential phases. + + The pipeline runs each phase's inner scenario in declaration order. Each + inner scenario owns its own dataset, strategies, scorer, and + ``ScenarioResult`` — the pipeline's own ``ScenarioResult`` records the + composition and per-phase outcomes, not the inner ``AttackResult``s + themselves. To inspect per-phase results after a pipeline run, walk + :attr:`phase_executions` and pull ``execution.scenario_result.attack_results`` + for each completed phase. + + Example:: + + pipeline = ScenarioPipeline( + phases=[ + PhaseSpec(name="sweep", scenario_factory=lambda: RapidResponse(...)), + ConditionalPhaseSpec( + name="deep_dive", + scenario_factory=lambda: RapidResponse(...), + init_async_kwargs={"scenario_strategies": [RapidResponseStrategy.ManyShot]}, + skip_when=lambda ctx: "sweep" not in ctx.completed_phase_names, + ), + ], + ) + await pipeline.initialize_async(objective_target=target) + result = await pipeline.run_async() + for execution in pipeline.phase_executions: + print(execution.name, execution.outcome) + + Pipeline-level resume is not supported in this release: pass + ``scenario_result_id=None`` (the default) for every run. Inner scenarios + that need their own resume should manage that themselves inside their + ``scenario_factory``. + """ + + VERSION: ClassVar[int] = 1 + + #: Pipelines compose other scenarios; baseline injection only makes sense + #: per inner scenario, so prevent the base class from prepending one at + #: the pipeline level. Inner scenarios independently respect their own + #: ``BASELINE_ATTACK_POLICY``. + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden + + @apply_defaults + def __init__( + self, + *, + phases: Sequence[PhaseSpec], + objective_scorer: Optional[Scorer] = None, + name: Optional[str] = None, + scenario_result_id: Optional[str] = None, + ) -> None: + """ + Initialize the pipeline. + + Args: + phases (Sequence[PhaseSpec]): Ordered list of phase declarations. + Must be non-empty. Phase names must be unique across the list + so the per-phase result bucket can be addressed unambiguously. + objective_scorer (Optional[Scorer]): Pipeline-level scorer. Not + used during execution (each inner scenario brings its own + scorer) but recorded on the pipeline's + :class:`ScenarioResult` for downstream introspection. A + deterministic no-op scorer is used when not provided. + name (Optional[str]): Display name for the pipeline. Defaults to + ``"ScenarioPipeline"``. + scenario_result_id (Optional[str]): Must be ``None`` in this + release — pipeline-level resume is not yet supported. + + Raises: + ValueError: If ``phases`` is empty, contains duplicate names, or + ``scenario_result_id`` is supplied. + """ + if not phases: + raise ValueError("ScenarioPipeline requires at least one PhaseSpec.") + + names = [phase.name for phase in phases] + seen: set[str] = set() + duplicates: list[str] = [] + for phase_name in names: + if phase_name in seen and phase_name not in duplicates: + duplicates.append(phase_name) + seen.add(phase_name) + if duplicates: + raise ValueError( + f"PhaseSpec names must be unique within a ScenarioPipeline; duplicates: {sorted(duplicates)!r}" + ) + + if scenario_result_id is not None: + raise ValueError( + "ScenarioPipeline does not support pipeline-level resume yet. " + "Pass scenario_result_id=None and recreate inner scenarios via PhaseSpec.scenario_factory." + ) + + self._phases: list[PhaseSpec] = list(phases) + self._phase_executions: list[PhaseExecution] = [] + + effective_scorer = objective_scorer if objective_scorer is not None else _PipelineNoopScorer() + + super().__init__( + name=name or type(self).__name__, + version=self.VERSION, + strategy_class=ScenarioPipelineStrategy, + objective_scorer=effective_scorer, + ) + + @classmethod + def get_strategy_class(cls) -> type[ScenarioStrategy]: + """Return the (single-member) strategy enum class.""" + return ScenarioPipelineStrategy + + @classmethod + def get_default_strategy(cls) -> ScenarioStrategy: + """Return the only strategy member.""" + return ScenarioPipelineStrategy.DEFAULT + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + """ + Return an empty dataset configuration. + + Each phase's inner scenario brings its own dataset, so the pipeline + has nothing to resolve at the outer level. + + Returns: + DatasetConfiguration: An empty configuration used by the base + ``initialize_async`` when the caller does not supply one. + """ + return DatasetConfiguration() + + @classmethod + def input_schema(cls) -> list[RoleDescriptor]: + """ + Declare the rich-object input the wizard / artifact must capture. + + ``phases`` is opaque because each :class:`PhaseSpec` holds a callable + factory whose closure cannot be reconstructed from primitive args. + Pipelines must be authored programmatically (or round-tripped via + a saved graph artifact). + + Returns: + list[RoleDescriptor]: Two roles — ``phases`` (opaque, required) + and ``name`` (scalar, optional). + """ + return [ + RoleDescriptor( + name="phases", + description=( + "Ordered list of PhaseSpec / ConditionalPhaseSpec instances declaring the inner scenarios " + "to run as pipeline phases. Each holds a callable factory and is therefore opaque." + ), + tag=RoleTag.OPAQUE, + required=True, + ), + RoleDescriptor( + name="name", + description="Display name for the pipeline. Defaults to the class name.", + tag=RoleTag.SCALAR, + param_type=str, + default=None, + required=False, + ), + ] + + @property + def phases(self) -> tuple[PhaseSpec, ...]: + """Return the declared phase specs as an immutable tuple.""" + return tuple(self._phases) + + @property + def phase_executions(self) -> tuple[PhaseExecution, ...]: + """Return snapshots of phase executions from the most recent run.""" + return tuple(self._phase_executions) + + def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: + """ + Return an empty factory map: pipelines do not own techniques directly. + + Each inner scenario owns its own factories; the pipeline-level + registry inspection (used by ``policy_to_spec``) is intentionally + empty so introspection stays side-effect-free at the pipeline layer. + + Returns: + dict[str, AttackTechniqueFactory]: Always empty. + """ + return {} + + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + """ + Materialize one :class:`_ScenarioPipelinePhaseStep` per phase. + + The orchestrator's resume bookkeeping reads ``atomic_attack_name``, + ``display_group``, ``seed_groups``, ``objectives``, and + ``technique_eval_hash`` from each returned step. The phase step + duck-types all of these so the base class can plug pipelines into + the existing execution flow without branching. + + Returns: + list[AtomicAttack]: One step per phase, cast to satisfy the + base contract. Phase steps subclass :class:`ScenarioStep`, + not :class:`AtomicAttack`; dispatch happens through + ``process_async`` via the pipeline's custom execution graph + rather than through ``AtomicAttack.run_async``. + """ + steps: list[ScenarioStep] = [ + _ScenarioPipelinePhaseStep(spec=spec, index=index, pipeline=self) for index, spec in enumerate(self._phases) + ] + return cast("list[AtomicAttack]", steps) + + async def _get_remaining_atomic_attacks_async(self) -> list[AtomicAttack]: + """ + Skip the orchestrator's atomic-attack resume bookkeeping for pipelines. + + Pipeline-level resume is not supported in this release (every run + constructs fresh inner scenarios with no scenario_result_id), so + ``_get_completed_objective_hashes_for_attack`` would return an + empty set for every phase step anyway. Returning ``self._atomic_attacks`` + directly avoids invoking the AtomicAttack-shaped helpers + (``technique_eval_hash`` lookup, ``drop_seed_groups_with_hashes``) + on the pipeline phase steps even though they expose stable stubs + for those attributes. + + Returns: + list[AtomicAttack]: All phase steps, in declaration order. + """ + return self._atomic_attacks + + def _build_execution_graph( + self, + *, + steps: Optional[Sequence[ScenarioStep]] = None, + ) -> StrategyGraph[ScenarioStep, int]: + """ + Build a linear integer-state graph that walks the supplied phase steps. + + Each phase becomes one policy action keyed by its zero-based index. + State ``len(phases)`` is the terminal state. Reset + ``self._phase_executions`` on every invocation so each pipeline run + starts with a clean execution log for :class:`PipelineContext`. + + Args: + steps: Phase steps to dispatch. The base ``_execute_scenario_async`` + passes the resume-filtered ``_atomic_attacks`` list here. + + Returns: + StrategyGraph[ScenarioStep, int]: A linear graph with one state + per phase, terminating after the final phase's action runs. + + Raises: + ValueError: If the resolved step list is empty. + """ + if steps is None: + steps = self._atomic_attacks + + phase_steps = list(steps) + if not phase_steps: + raise ValueError("ScenarioPipeline cannot build an execution graph without phase steps.") + + self._phase_executions = [] + + actions: dict[int, PolicyAction[ScenarioStep, int]] = {} + for index, step in enumerate(phase_steps): + actions[index] = self._build_phase_action(index=index, step=step) + + terminal_state = len(phase_steps) + return StrategyGraph( + policy=StrategyPolicy( + actions=actions, + initial_state=0, + terminal_states=frozenset({terminal_state}), + ), + ) + + def _build_phase_action( + self, + *, + index: int, + step: ScenarioStep, + ) -> PolicyAction[ScenarioStep, int]: + """ + Build the policy action for one phase. + + The action binds the phase step as the current step (so the + orchestrator can resolve ``graph.current_step`` mid-dispatch), + invokes ``step.process_async``, and advances to the next integer + state. The step's metadata (``phase_name``, ``phase_index``) is + merged into the result's metadata so the orchestrator's logging / + step_identifier path sees a consistent step name. + + Args: + index: Zero-based phase index in the linear graph. + step: The phase step to dispatch from this state. + + Returns: + PolicyAction: An async action that runs the phase and transitions + to state ``index + 1``. + """ + next_state = index + 1 + + async def _phase_action( + graph: StrategyGraph[ScenarioStep, int], + ) -> tuple[int, ScenarioStepResult | None]: + graph.bind_current_step(step=step) + try: + base_result = await step.process_async() + merged_metadata = { + "step_name": step.name, + "phase_index": index, + **base_result.metadata, + } + result = ScenarioStepResult( + outcome=base_result.outcome, + attack_results=list(base_result.attack_results), + step_identifier=base_result.step_identifier, + metadata=merged_metadata, + ) + finally: + graph.bind_current_step(step=None) + return next_state, result + + return _phase_action + + def _record_phase_execution(self, *, execution: PhaseExecution) -> None: + """Append a phase execution snapshot to the pipeline's run-scoped log.""" + self._phase_executions.append(execution) + + def _snapshot_pipeline_context(self) -> PipelineContext: + """ + Build the immutable :class:`PipelineContext` view for predicate evaluation. + + Returns: + PipelineContext: A frozen snapshot of every phase execution that + has completed (or been skipped) so far during this pipeline + run. + """ + executions = tuple(self._phase_executions) + return PipelineContext( + completed_phase_names=tuple(execution.name for execution in executions), + completed_phase_outcomes=tuple(execution.outcome for execution in executions), + phase_executions=executions, + ) diff --git a/tests/unit/scenario/composite/__init__.py b/tests/unit/scenario/composite/__init__.py new file mode 100644 index 000000000..9a0454564 --- /dev/null +++ b/tests/unit/scenario/composite/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/scenario/composite/test_scenario_pipeline.py b/tests/unit/scenario/composite/test_scenario_pipeline.py new file mode 100644 index 000000000..dc6021497 --- /dev/null +++ b/tests/unit/scenario/composite/test_scenario_pipeline.py @@ -0,0 +1,486 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for :mod:`pyrit.scenario.composite.scenario_pipeline`. + +Pins both the static surface (dataclass frozenness, input schema shape, +construction validation) and the dynamic behavior (phase dispatch order, +conditional skipping, predicate context, target/labels inheritance) of the +``ScenarioPipeline`` composition primitive. +""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.identifiers import ComponentIdentifier +from pyrit.scenario import ( + ConditionalPhaseSpec, + PhaseExecution, + PhaseSpec, + PipelineContext, + Scenario, + ScenarioPipeline, + ScenarioPipelineStrategy, +) +from pyrit.scenario.composite.scenario_pipeline import ( + _PHASE_OUTCOME_COMPLETED, + _PHASE_OUTCOME_SKIPPED, + _PipelineNoopScorer, + _ScenarioPipelinePhaseStep, +) +from pyrit.scenario.core.input_schema import RoleTag +from pyrit.score import Scorer + +_PIPELINE_SCORER_ID = ComponentIdentifier( + class_name="MockScorer", + class_module="tests.unit.scenario.composite", +) + + +def _make_inner_scenario(name: str = "inner") -> MagicMock: + """Build a ``MagicMock(spec=Scenario)`` that satisfies the pipeline contract.""" + inner = MagicMock(spec=Scenario) + inner.name = name + fake_result = MagicMock() + fake_result.id = f"result-{name}" + inner.initialize_async = AsyncMock() + inner.run_async = AsyncMock(return_value=fake_result) + return inner + + +def _phase_spec( + name: str, + *, + inner: MagicMock | None = None, + init_kwargs: dict[str, Any] | None = None, +) -> tuple[PhaseSpec, MagicMock]: + """Build a ``PhaseSpec`` whose factory returns the supplied (or fresh) inner mock.""" + inner = inner if inner is not None else _make_inner_scenario(name) + spec = PhaseSpec( + name=name, + scenario_factory=lambda: inner, + init_async_kwargs=init_kwargs or {}, + ) + return spec, inner + + +def _conditional_phase_spec( + name: str, + *, + skip_when, + inner: MagicMock | None = None, + init_kwargs: dict[str, Any] | None = None, +) -> tuple[ConditionalPhaseSpec, MagicMock]: + inner = inner if inner is not None else _make_inner_scenario(name) + spec = ConditionalPhaseSpec( + name=name, + scenario_factory=lambda: inner, + init_async_kwargs=init_kwargs or {}, + skip_when=skip_when, + ) + return spec, inner + + +@pytest.fixture +def mock_objective_target(): + target = MagicMock() + target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", + class_module="tests.unit.scenario.composite", + ) + return target + + +@pytest.fixture +def mock_scorer(): + scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = _PIPELINE_SCORER_ID + scorer.get_scorer_metrics.return_value = None + return scorer + + +# --------------------------------------------------------------------------- # +# Static surface +# --------------------------------------------------------------------------- # + + +class TestPhaseSpecDataclasses: + """Dataclass shape: frozen, default factory, conditional inheritance.""" + + def test_phase_spec_is_frozen(self): + spec, _ = _phase_spec("p") + with pytest.raises((AttributeError, Exception)): + spec.name = "renamed" # type: ignore[misc] + + def test_phase_spec_init_async_kwargs_defaults_to_empty(self): + inner = _make_inner_scenario() + spec = PhaseSpec(name="p", scenario_factory=lambda: inner) + assert dict(spec.init_async_kwargs) == {} + + def test_conditional_phase_spec_inherits_phase_spec_fields(self): + spec, _ = _conditional_phase_spec("p", skip_when=lambda _ctx: True) + assert isinstance(spec, PhaseSpec) + assert spec.name == "p" + + def test_conditional_phase_spec_default_skip_when_runs_phase(self): + inner = _make_inner_scenario() + spec = ConditionalPhaseSpec(name="p", scenario_factory=lambda: inner) + # Default skip_when returns False (= run) for any context. + assert spec.skip_when(PipelineContext()) is False + + def test_phase_execution_is_frozen(self): + execution = PhaseExecution(name="p", outcome=_PHASE_OUTCOME_COMPLETED, scenario_result=None) + with pytest.raises((AttributeError, Exception)): + execution.name = "renamed" # type: ignore[misc] + + def test_pipeline_context_is_frozen_with_tuple_defaults(self): + context = PipelineContext() + assert context.completed_phase_names == () + assert context.completed_phase_outcomes == () + assert context.phase_executions == () + with pytest.raises((AttributeError, Exception)): + context.completed_phase_names = ("x",) # type: ignore[misc] + + +class TestScenarioPipelineStrategy: + """The single-member enum that satisfies the base ``Scenario`` contract.""" + + def test_default_member_exists(self): + assert ScenarioPipelineStrategy.DEFAULT.value == "default" + + def test_aggregate_tags_contains_all(self): + assert "all" in ScenarioPipelineStrategy.get_aggregate_tags() + + +# --------------------------------------------------------------------------- # +# Construction validation +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestConstruction: + """Constructor input validation.""" + + def test_rejects_empty_phases(self): + with pytest.raises(ValueError, match="at least one PhaseSpec"): + ScenarioPipeline(phases=[]) + + def test_rejects_duplicate_names(self): + spec_a, _ = _phase_spec("a") + spec_b, _ = _phase_spec("a") + with pytest.raises(ValueError, match="must be unique"): + ScenarioPipeline(phases=[spec_a, spec_b]) + + def test_rejects_scenario_result_id(self): + spec, _ = _phase_spec("a") + with pytest.raises(ValueError, match="does not support pipeline-level resume"): + ScenarioPipeline(phases=[spec], scenario_result_id="some-id") + + def test_constructs_with_single_phase_and_default_scorer(self): + spec, _ = _phase_spec("only") + pipeline = ScenarioPipeline(phases=[spec]) + assert pipeline.phases == (spec,) + # Default scorer is the deterministic no-op stand-in. + assert isinstance(pipeline._objective_scorer, _PipelineNoopScorer) + + def test_constructs_with_explicit_scorer(self, mock_scorer): + spec, _ = _phase_spec("only") + pipeline = ScenarioPipeline(phases=[spec], objective_scorer=mock_scorer) + assert pipeline._objective_scorer is mock_scorer + + def test_name_defaults_to_class_name(self): + spec, _ = _phase_spec("only") + assert ScenarioPipeline(phases=[spec]).name == "ScenarioPipeline" + + def test_name_override_wins(self): + spec, _ = _phase_spec("only") + assert ScenarioPipeline(phases=[spec], name="My Pipeline").name == "My Pipeline" + + +# --------------------------------------------------------------------------- # +# Input schema +# --------------------------------------------------------------------------- # + + +class TestInputSchema: + """``input_schema()`` exposes the two pipeline-level roles.""" + + def test_schema_lists_phases_and_name(self): + schema = ScenarioPipeline.input_schema() + names = [r.name for r in schema] + assert names == ["phases", "name"] + + def test_phases_role_is_opaque_and_required(self): + phases_role = next(r for r in ScenarioPipeline.input_schema() if r.name == "phases") + assert phases_role.tag is RoleTag.OPAQUE + assert phases_role.required is True + + def test_name_role_is_scalar_and_optional(self): + name_role = next(r for r in ScenarioPipeline.input_schema() if r.name == "name") + assert name_role.tag is RoleTag.SCALAR + assert name_role.required is False + + +# --------------------------------------------------------------------------- # +# Phase step (custom ScenarioStep) +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestPhaseStep: + """The custom ``_ScenarioPipelinePhaseStep`` orchestrator-facing surface.""" + + def test_duck_typed_atomic_attack_attributes(self): + spec, _ = _phase_spec("dummy") + pipeline = ScenarioPipeline(phases=[spec]) + step = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + assert step.atomic_attack_name == "dummy" + assert step.display_group == "dummy" + assert step.seed_groups == [] + assert step.objectives == [] + + def test_technique_eval_hash_sentinel_is_phase_scoped(self): + spec, _ = _phase_spec("alpha") + pipeline = ScenarioPipeline(phases=[spec]) + step = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + # The sentinel namespaces by phase name so it can never collide with a real eval hash. + assert step.technique_eval_hash == "pipeline-phase::alpha" + + def test_graph_artifact_opaque_marker_present(self): + # Pipeline phases close over a callable factory; opacity is the correct + # contract for Phase 8g GraphArtifact serialization. + assert _ScenarioPipelinePhaseStep.GRAPH_ARTIFACT_OPAQUE is True + + def test_identifier_is_deterministic(self): + spec, _ = _phase_spec("alpha") + pipeline = ScenarioPipeline(phases=[spec]) + s1 = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + s2 = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + assert s1.get_identifier().hash == s2.get_identifier().hash + + def test_identifier_differs_by_index(self): + spec, _ = _phase_spec("alpha") + pipeline = ScenarioPipeline(phases=[spec]) + s_a = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + s_b = _ScenarioPipelinePhaseStep(spec=spec, index=1, pipeline=pipeline) + assert s_a.get_identifier().hash != s_b.get_identifier().hash + + def test_filter_seed_groups_is_noop(self): + spec, _ = _phase_spec("alpha") + pipeline = ScenarioPipeline(phases=[spec]) + step = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + # No raise; pipeline phase steps own no seed groups at the pipeline level. + step.filter_seed_groups_by_objectives(remaining_objectives=["o1"]) + step.drop_seed_groups_with_hashes(hashes={"h1"}) + + +# --------------------------------------------------------------------------- # +# Execution graph build +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestBuildExecutionGraph: + """Linear int-state graph construction.""" + + async def test_graph_has_one_state_per_phase(self, mock_objective_target): + spec_a, _ = _phase_spec("a") + spec_b, _ = _phase_spec("b") + spec_c, _ = _phase_spec("c") + pipeline = ScenarioPipeline(phases=[spec_a, spec_b, spec_c]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + graph = pipeline._build_execution_graph() + assert graph.policy.initial_state == 0 + assert graph.policy.terminal_states == frozenset({3}) + for state in range(3): + assert callable(graph.policy.get_action(state=state)) + + async def test_explicit_steps_override_atomic_attacks(self, mock_objective_target): + spec_a, _ = _phase_spec("a") + spec_b, _ = _phase_spec("b") + pipeline = ScenarioPipeline(phases=[spec_a, spec_b]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + single_step = _ScenarioPipelinePhaseStep(spec=spec_a, index=0, pipeline=pipeline) + graph = pipeline._build_execution_graph(steps=[single_step]) + assert graph.policy.terminal_states == frozenset({1}) + + async def test_build_raises_with_empty_steps(self, mock_objective_target): + spec, _ = _phase_spec("a") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + with pytest.raises(ValueError, match="cannot build an execution graph"): + pipeline._build_execution_graph(steps=[]) + + async def test_build_resets_phase_executions_log(self, mock_objective_target): + spec, _ = _phase_spec("a") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + pipeline._phase_executions.append( + PhaseExecution(name="stale", outcome=_PHASE_OUTCOME_COMPLETED, scenario_result=None) + ) + pipeline._build_execution_graph() + assert pipeline._phase_executions == [] + + +# --------------------------------------------------------------------------- # +# End-to-end execution +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestPipelineExecution: + """End-to-end ``run_async`` behavior with mocked inner scenarios.""" + + async def test_single_phase_runs_and_records_execution(self, mock_objective_target): + spec, inner = _phase_spec("only") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + inner.initialize_async.assert_awaited_once() + inner.run_async.assert_awaited_once() + assert len(pipeline.phase_executions) == 1 + execution = pipeline.phase_executions[0] + assert execution.name == "only" + assert execution.outcome == _PHASE_OUTCOME_COMPLETED + assert execution.scenario_result is inner.run_async.return_value + + async def test_two_phases_run_in_declaration_order(self, mock_objective_target): + order: list[str] = [] + + def _make_recording_inner(name: str): + inner = _make_inner_scenario(name) + + async def _record_init(**_kwargs): + order.append(f"init:{name}") + + async def _record_run(): + order.append(f"run:{name}") + return inner.run_async.return_value + + inner.initialize_async.side_effect = _record_init + inner.run_async.side_effect = _record_run + return inner + + inner_a = _make_recording_inner("a") + inner_b = _make_recording_inner("b") + spec_a, _ = _phase_spec("a", inner=inner_a) + spec_b, _ = _phase_spec("b", inner=inner_b) + + pipeline = ScenarioPipeline(phases=[spec_a, spec_b]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + assert order == ["init:a", "run:a", "init:b", "run:b"] + assert [e.name for e in pipeline.phase_executions] == ["a", "b"] + assert all(e.outcome == _PHASE_OUTCOME_COMPLETED for e in pipeline.phase_executions) + + async def test_conditional_phase_skipped_does_not_call_factory(self, mock_objective_target): + factory_calls: list[str] = [] + inner = _make_inner_scenario("skipme") + + def _factory(): + factory_calls.append("called") + return inner + + spec = ConditionalPhaseSpec( + name="skipme", + scenario_factory=_factory, + skip_when=lambda _ctx: True, + ) + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + assert factory_calls == [] + inner.initialize_async.assert_not_called() + inner.run_async.assert_not_called() + assert len(pipeline.phase_executions) == 1 + assert pipeline.phase_executions[0].outcome == _PHASE_OUTCOME_SKIPPED + assert pipeline.phase_executions[0].scenario_result is None + + async def test_conditional_phase_runs_when_predicate_returns_false(self, mock_objective_target): + spec, inner = _conditional_phase_spec("runme", skip_when=lambda _ctx: False) + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + inner.run_async.assert_awaited_once() + assert pipeline.phase_executions[0].outcome == _PHASE_OUTCOME_COMPLETED + + async def test_predicate_receives_context_of_prior_phases(self, mock_objective_target): + captured: list[PipelineContext] = [] + + spec_a, _ = _phase_spec("a") + spec_b, _ = _conditional_phase_spec( + "b", + skip_when=lambda ctx: (captured.append(ctx), False)[1], + ) + spec_c, _ = _conditional_phase_spec( + "c", + skip_when=lambda ctx: (captured.append(ctx), True)[1], + ) + + pipeline = ScenarioPipeline(phases=[spec_a, spec_b, spec_c]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + assert len(captured) == 2 + # When phase b evaluates, only phase a has executed. + ctx_b = captured[0] + assert ctx_b.completed_phase_names == ("a",) + assert ctx_b.completed_phase_outcomes == (_PHASE_OUTCOME_COMPLETED,) + # When phase c evaluates, phases a and b have both executed (b ran). + ctx_c = captured[1] + assert ctx_c.completed_phase_names == ("a", "b") + assert ctx_c.completed_phase_outcomes == (_PHASE_OUTCOME_COMPLETED, _PHASE_OUTCOME_COMPLETED) + + async def test_pipeline_objective_target_flows_to_inner_scenario(self, mock_objective_target): + spec, inner = _phase_spec("p") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + inner.initialize_async.assert_awaited_once() + call_kwargs = inner.initialize_async.await_args.kwargs + assert call_kwargs["objective_target"] is mock_objective_target + + async def test_pipeline_memory_labels_flow_to_inner_scenario(self, mock_objective_target): + spec, inner = _phase_spec("p") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async( + objective_target=mock_objective_target, + memory_labels={"campaign": "alpha"}, + ) + await pipeline.run_async() + + call_kwargs = inner.initialize_async.await_args.kwargs + assert call_kwargs["memory_labels"] == {"campaign": "alpha"} + + async def test_init_async_kwargs_override_pipeline_target(self, mock_objective_target): + custom_target = MagicMock() + custom_target.get_identifier.return_value = ComponentIdentifier( + class_name="OverrideTarget", + class_module="tests.unit.scenario.composite", + ) + spec, inner = _phase_spec("p", init_kwargs={"objective_target": custom_target}) + + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + await pipeline.run_async() + + call_kwargs = inner.initialize_async.await_args.kwargs + assert call_kwargs["objective_target"] is custom_target + + async def test_factory_returning_non_scenario_raises_type_error(self, mock_objective_target): + spec = PhaseSpec(name="bad", scenario_factory=lambda: "not a scenario") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + with pytest.raises(TypeError, match="'bad'"): + await pipeline.run_async() From eb2be7f9ddd4d9135e3bea49b5e730d0ffe41a0c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 11:31:20 -0700 Subject: [PATCH 34/42] pin ScenarioPipeline scenario_factory guard uses isinstance not duck-typing Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../composite/test_scenario_pipeline.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/unit/scenario/composite/test_scenario_pipeline.py b/tests/unit/scenario/composite/test_scenario_pipeline.py index dc6021497..361fcdda5 100644 --- a/tests/unit/scenario/composite/test_scenario_pipeline.py +++ b/tests/unit/scenario/composite/test_scenario_pipeline.py @@ -484,3 +484,34 @@ async def test_factory_returning_non_scenario_raises_type_error(self, mock_objec await pipeline.initialize_async(objective_target=mock_objective_target) with pytest.raises(TypeError, match="'bad'"): await pipeline.run_async() + + async def test_factory_returning_duck_typed_non_scenario_class_raises_type_error(self, mock_objective_target): + """Pin that the guard uses :func:`isinstance`, not duck-typing. + + A class exposing ``initialize_async``/``run_async`` methods but not + inheriting from :class:`Scenario` must be rejected. This guards against + a future refactor that swaps ``isinstance(inner_scenario, Scenario)`` + for an attribute-presence check, which would silently accept duck-typed + objects whose ``run_async`` returns an incompatible result shape. + """ + + class _DuckScenario: + async def initialize_async(self, **kwargs: Any) -> None: + pass + + async def run_async(self, **kwargs: Any) -> Any: + return None + + spec = PhaseSpec(name="duck", scenario_factory=lambda: _DuckScenario()) + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + with pytest.raises(TypeError, match="'duck'"): + await pipeline.run_async() + + async def test_factory_returning_none_raises_type_error(self, mock_objective_target): + """A factory that returns ``None`` is a common typo (missing ``return``).""" + spec = PhaseSpec(name="forgot_return", scenario_factory=lambda: None) + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + with pytest.raises(TypeError, match="'forgot_return'"): + await pipeline.run_async() From 3be99e4a4325a584e54e3711a8c0a8931ac61815 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 11:41:17 -0700 Subject: [PATCH 35/42] pin ScenarioPipeline graph artifact round-trip not supported in v1 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/composite/scenario_pipeline.py | 8 ++- .../composite/test_scenario_pipeline.py | 68 +++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/pyrit/scenario/composite/scenario_pipeline.py b/pyrit/scenario/composite/scenario_pipeline.py index 7ae6363f2..2f6159d00 100644 --- a/pyrit/scenario/composite/scenario_pipeline.py +++ b/pyrit/scenario/composite/scenario_pipeline.py @@ -538,8 +538,12 @@ def input_schema(cls) -> list[RoleDescriptor]: ``phases`` is opaque because each :class:`PhaseSpec` holds a callable factory whose closure cannot be reconstructed from primitive args. - Pipelines must be authored programmatically (or round-tripped via - a saved graph artifact). + Pipelines must be authored programmatically — round-tripping a + pipeline via a saved graph artifact is not supported in v1, because + neither the ``Callable`` factory nor ``PhaseSpec.init_async_kwargs`` + (a :class:`types.MappingProxyType`) is YAML-serializable. See + :class:`tests.unit.scenario.composite.test_scenario_pipeline.TestArtifactRoundTripNotSupported` + for the regression pin. Returns: list[RoleDescriptor]: Two roles — ``phases`` (opaque, required) diff --git a/tests/unit/scenario/composite/test_scenario_pipeline.py b/tests/unit/scenario/composite/test_scenario_pipeline.py index 361fcdda5..cb93d5891 100644 --- a/tests/unit/scenario/composite/test_scenario_pipeline.py +++ b/tests/unit/scenario/composite/test_scenario_pipeline.py @@ -515,3 +515,71 @@ async def test_factory_returning_none_raises_type_error(self, mock_objective_tar await pipeline.initialize_async(objective_target=mock_objective_target) with pytest.raises(TypeError, match="'forgot_return'"): await pipeline.run_async() + + +# --------------------------------------------------------------------------- # +# Artifact round-trip limitation (v1 contract pin) +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestArtifactRoundTripNotSupported: + """Pin the v1 contract that pipelines cannot round-trip via graph artifacts. + + A :class:`PhaseSpec` holds a ``Callable`` scenario factory and a + :class:`types.MappingProxyType` ``init_async_kwargs`` (when defaulted) — + neither survives YAML serialization. The pipeline ``input_schema`` + docstring states this explicitly; these tests pin both failure modes in + code so that any future change to add round-trip support must update the + contract intentionally (delete these tests, not flip them to + ``assert success``). + + Related finding for the graph_artifact owner: ``_encode_init_inputs`` + silently passes through OPAQUE values that lack ``get_identifier`` (see + the "Unknown opaque shape — defer to caller serialization at their own + risk." branch), which is what lets pipelines reach the (broken) + serialization path at all. A fail-loud guard there would raise a clearer + error than either downstream failure mode below. + """ + + async def test_default_init_async_kwargs_mappingproxy_breaks_asdict(self, mock_objective_target): + """PhaseSpec defaults ``init_async_kwargs`` to a MappingProxyType, which + ``dataclasses.asdict`` cannot ``deepcopy``. ``graph_artifact_to_yaml`` + starts with ``asdict(artifact)`` so this is the first failure point. + """ + from dataclasses import asdict + + from pyrit.scenario.core.graph_artifact import build_graph_artifact + + inner = _make_inner_scenario("p1") + spec = PhaseSpec(name="p1", scenario_factory=lambda: inner) + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + artifact = build_graph_artifact(pipeline, init_inputs={"phases": [spec]}) + assert isinstance(artifact.init_inputs["phases"], list) + assert isinstance(artifact.init_inputs["phases"][0], PhaseSpec) + + with pytest.raises(TypeError, match="mappingproxy"): + asdict(artifact) + + async def test_callable_factory_breaks_yaml_dump(self, mock_objective_target, tmp_path): + """Even when ``init_async_kwargs`` is supplied as a plain dict so + ``asdict`` succeeds, YAML's ``SafeDumper`` cannot represent the + Callable scenario factory. + """ + import yaml + + from pyrit.scenario.core.graph_artifact import ( + build_graph_artifact, + graph_artifact_to_yaml, + ) + + spec, _ = _phase_spec("p1") # init_kwargs={} avoids the MappingProxy path + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + artifact = build_graph_artifact(pipeline, init_inputs={"phases": [spec]}) + out_path = tmp_path / "pipeline.yaml" + with pytest.raises(yaml.representer.RepresenterError, match="cannot represent"): + graph_artifact_to_yaml(artifact, out_path) From c35a77c2ea45e830097653067711d5c9beebd5f4 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 11:54:10 -0700 Subject: [PATCH 36/42] pin BroadSweepThenDeepDive graph artifact round-trip not supported in v1 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/airt/sweep_then_deep_dive.py | 22 +++- .../airt/test_sweep_then_deep_dive.py | 113 ++++++++++++++++++ 2 files changed, 130 insertions(+), 5 deletions(-) diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index 8f072185f..4af7ca026 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -536,11 +536,23 @@ def input_schema(cls) -> list[RoleDescriptor]: """ Declare the rich-object and scalar inputs the wizard / artifact must capture. - The three opaque roles are pre-built ``Identifiable`` instances that the - CLI wizard cannot elicit directly — programmatic callers must supply them - and CLI flows must round-trip them through a saved graph artifact (see - :class:`pyrit.scenario.core.graph_artifact.GraphArtifact`). The single - scalar role (``weakness_label``) is freely elicitable. + The three opaque roles are pre-built instances that the CLI wizard + cannot elicit directly — programmatic callers must supply them. + Round-tripping a fully-built ``BroadSweepThenDeepDive`` through a + saved graph artifact is **not supported in v1**: + + * ``sweep_atomic_attack`` is a single ``Identifiable`` and would + encode cleanly, but + * ``deep_dive_atomic_attacks`` is a ``list`` (no ``get_identifier``) + and ``outcome_scorer`` is an :class:`OutcomeScorer` (no + ``get_identifier``). Both fall into the "Unknown opaque shape — + defer to caller serialization at their own risk" branch of + :func:`pyrit.scenario.core.graph_artifact._encode_init_inputs` and + break ``yaml.safe_dump``. See + :class:`tests.unit.scenario.scenarios.airt.test_sweep_then_deep_dive.TestArtifactRoundTripNotSupported` + for the regression pin. + + The single scalar role (``weakness_label``) is freely elicitable. Returns: list[RoleDescriptor]: Three OPAQUE roles plus one SCALAR. diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py index a0854d871..e17f79456 100644 --- a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive.py @@ -691,3 +691,116 @@ async def test_consecutive_runs_reset_state(self) -> None: assert scenario._weak_categories == set() async for _ in graph2.event_loop_async(): pass + + +# --------------------------------------------------------------------------- # +# Artifact round-trip limitation (v1 contract pin) +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestArtifactRoundTripNotSupported: + """Pin the v1 contract that ``BroadSweepThenDeepDive`` cannot round-trip via graph artifacts. + + Two of the three OPAQUE input roles bypass identifier encoding: + + * ``deep_dive_atomic_attacks`` is a ``list``. The list itself has no + ``get_identifier``, so :func:`_encode_init_inputs` falls through to + its "Unknown opaque shape — defer to caller serialization at their + own risk" branch and stores the raw list of ``AtomicAttack`` instances + verbatim. ``yaml.SafeDumper`` then raises ``RepresenterError`` on the + first element. + * ``outcome_scorer`` is an :class:`OutcomeScorer` instance which has no + ``get_identifier`` method (it's a wrapper around a scorer, not an + Identifiable). Same fall-through; same downstream failure if YAML + reached it. + + Only ``sweep_atomic_attack`` (a single AtomicAttack with + ``get_identifier``) round-trips correctly. This class pins the v1 + failure mode so any future change adding proper round-trip support must + update the contract intentionally (delete this class, not flip it to + ``assert success``). + + Related: the same gap manifests in :class:`ScenarioPipeline`'s + ``phases`` role (pinned by + ``tests.unit.scenario.composite.test_scenario_pipeline.TestArtifactRoundTripNotSupported``). + The right fix is a recursive container encoder in + :func:`_encode_init_inputs` plus a fail-loud ``OpaqueInputUnserializableError`` + at leaf level; tracked as a follow-up to PR #1767. + """ + + @staticmethod + def _build_scenario_and_inputs() -> tuple[BroadSweepThenDeepDive, dict[str, Any], MagicMock]: + wrapped_scorer = MagicMock(spec=Scorer) + wrapped_scorer.get_identifier.return_value = ComponentIdentifier( + class_name="MockScorer", + class_module="tests.unit.scenario.scenarios.airt", + ) + scorer = OutcomeScorer( + wrapped_scorer=wrapped_scorer, + outcome_map={ + _WEAKNESS_LABEL: lambda s: s.score_value == _WEAKNESS_LABEL, + _SAFE_LABEL: lambda s: s.score_value == _SAFE_LABEL, + }, + ) + sweep = _make_atomic_mock(name="sweep", display_group="cat-a", attack_results=[]) + deep = _make_atomic_mock(name="deep-0", display_group="cat-a", attack_results=[]) + scenario = BroadSweepThenDeepDive( + sweep_atomic_attack=cast("AtomicAttack", sweep), + deep_dive_atomic_attacks=[cast("AtomicAttack", deep)], + outcome_scorer=scorer, + weakness_label=_WEAKNESS_LABEL, + ) + target = MagicMock() + target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", + class_module="tests.unit.scenario.scenarios.airt", + ) + init_inputs: dict[str, Any] = { + "sweep_atomic_attack": sweep, + "deep_dive_atomic_attacks": [deep], + "outcome_scorer": scorer, + "weakness_label": _WEAKNESS_LABEL, + } + return scenario, init_inputs, target + + async def test_list_of_atomic_attacks_passes_through_then_yaml_chokes(self, tmp_path) -> None: + """``deep_dive_atomic_attacks`` (a list) skips identifier encoding and + breaks YAML serialization on the contained ``AtomicAttack`` mocks. + """ + import yaml + + from pyrit.scenario.core.graph_artifact import ( + build_graph_artifact, + graph_artifact_to_yaml, + ) + + scenario, init_inputs, target = self._build_scenario_and_inputs() + await scenario.initialize_async(objective_target=target) + + artifact = build_graph_artifact(scenario, init_inputs=init_inputs) + # Sanity check: the single-instance role IS encoded as an identifier dict, + # but the list-of-instances role is passed through as a live list. + assert isinstance(artifact.init_inputs["sweep_atomic_attack"], dict) + assert isinstance(artifact.init_inputs["deep_dive_atomic_attacks"], list) + + out_path = tmp_path / "bstd.yaml" + with pytest.raises(yaml.representer.RepresenterError, match="cannot represent"): + graph_artifact_to_yaml(artifact, out_path) + + async def test_outcome_scorer_lacks_get_identifier_and_passes_through(self) -> None: + """``OutcomeScorer`` is a wrapper without ``get_identifier``, so + ``_encode_init_inputs`` stores it verbatim instead of as an identifier + dict. This is the second OPAQUE role that would break YAML round-trip + even if the ``deep_dive_atomic_attacks`` list issue were fixed. + """ + from pyrit.scenario.core.graph_artifact import build_graph_artifact + + scenario, init_inputs, target = self._build_scenario_and_inputs() + await scenario.initialize_async(objective_target=target) + + artifact = build_graph_artifact(scenario, init_inputs=init_inputs) + stored = artifact.init_inputs["outcome_scorer"] + # Pinned passthrough: the live OutcomeScorer instance is what got stored. + assert isinstance(stored, OutcomeScorer) + assert not hasattr(stored, "get_identifier") From 75506afa552e1dedf6aeb43e6b0066b7821ca949 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 12:07:39 -0700 Subject: [PATCH 37/42] MAINT: honest ScenarioPipeline persistence docstring + R1 forward-compat stub Two duck-driven follow-ups to e61df6f3 (R5 ScenarioPipeline): - Class docstring rewritten to be explicit that per-phase outcomes live only on in-memory self._phase_executions in v1, not in the persisted outer ScenarioResult. The previous wording implied cross-process readers could inspect per-phase outcomes; they cannot until R5.1 wires phase_executions into metadata. - _ScenarioPipelinePhaseStep.set_scenario_result_id added as a no-op stub. Today the base orchestrator's isinstance(_step, AtomicAttack) guard makes this unreachable, but R1 plans to collapse that guard and dispatch uniformly via process_async. Any non-AtomicAttack ScenarioStep needs this method or R1 will break with AttributeError. Regression test pins the contract. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/composite/scenario_pipeline.py | 37 ++++++++++++++++--- .../composite/test_scenario_pipeline.py | 19 ++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/pyrit/scenario/composite/scenario_pipeline.py b/pyrit/scenario/composite/scenario_pipeline.py index 2f6159d00..1b9827928 100644 --- a/pyrit/scenario/composite/scenario_pipeline.py +++ b/pyrit/scenario/composite/scenario_pipeline.py @@ -299,6 +299,20 @@ def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) - def drop_seed_groups_with_hashes(self, *, hashes: set[str]) -> None: """No-op: pipeline phases own no seed groups at the pipeline level.""" + def set_scenario_result_id(self, scenario_result_id: str | None) -> None: + """ + No-op: pipeline phases don't get the outer pipeline's ``scenario_result_id``. + + Inner scenarios persist under their own ``scenario_result_id``s, so + the pipeline never propagates the outer id down to the phase step. + The base ``Scenario._execute_scenario_async`` guards this call with + ``isinstance(_step, AtomicAttack)`` today (phase steps don't subclass + ``AtomicAttack``), so this stub is currently unreachable from the + orchestrator. It exists for forward-compatibility with the planned + R1 follow-up that collapses the isinstance branch and dispatches + uniformly via ``process_async``. + """ + async def process_async(self) -> ScenarioStepResult: """ Run the inner scenario (or skip via ``ConditionalPhaseSpec`` predicate). @@ -406,11 +420,24 @@ class ScenarioPipeline(Scenario): The pipeline runs each phase's inner scenario in declaration order. Each inner scenario owns its own dataset, strategies, scorer, and - ``ScenarioResult`` — the pipeline's own ``ScenarioResult`` records the - composition and per-phase outcomes, not the inner ``AttackResult``s - themselves. To inspect per-phase results after a pipeline run, walk - :attr:`phase_executions` and pull ``execution.scenario_result.attack_results`` - for each completed phase. + ``ScenarioResult`` (with its own ``scenario_result_id`` persisted + independently). + + The pipeline's own ``ScenarioResult`` records only the **composition** — + each phase name appears as a key in ``attack_results``, but every bucket + is empty because the inner scenarios persist their own ``AttackResult``s + against their own ``scenario_result_id``s and re-emitting them under the + pipeline's id would duplicate-persist them. + + **Per-phase outcomes are NOT persisted in v1.** They live only on + :attr:`phase_executions`, an in-memory log scoped to the live pipeline + instance. To inspect per-phase outcomes after a pipeline run, you must + keep a reference to the pipeline object and walk + ``pipeline.phase_executions``; reading the persisted outer + ``ScenarioResult`` back via ``get_scenario_results`` from another process + will show only the composition (empty phase buckets), not the outcomes. + See ``tests.unit.scenario.composite.test_scenario_pipeline.TestPipelineExecution`` + for the in-process inspection pattern. Example:: diff --git a/tests/unit/scenario/composite/test_scenario_pipeline.py b/tests/unit/scenario/composite/test_scenario_pipeline.py index cb93d5891..7bc50fcd6 100644 --- a/tests/unit/scenario/composite/test_scenario_pipeline.py +++ b/tests/unit/scenario/composite/test_scenario_pipeline.py @@ -276,6 +276,25 @@ def test_filter_seed_groups_is_noop(self): step.filter_seed_groups_by_objectives(remaining_objectives=["o1"]) step.drop_seed_groups_with_hashes(hashes={"h1"}) + def test_set_scenario_result_id_is_noop_stub(self): + """Forward-compatibility with R1 (drop-isinstance refactor). + + Today the base ``Scenario._execute_scenario_async`` guards + ``set_scenario_result_id`` with ``isinstance(_step, AtomicAttack)``, + so this method is unreachable from the orchestrator. R1 plans to + collapse that branch and dispatch uniformly via ``process_async``; + when it does, every non-AtomicAttack ``ScenarioStep`` must expose + ``set_scenario_result_id``. This regression pins the stub so that + contract change doesn't silently introduce an ``AttributeError`` + for pipeline phase steps. + """ + spec, _ = _phase_spec("alpha") + pipeline = ScenarioPipeline(phases=[spec]) + step = _ScenarioPipelinePhaseStep(spec=spec, index=0, pipeline=pipeline) + # Stub should accept either a real id or None without raising. + step.set_scenario_result_id("any-id") + step.set_scenario_result_id(None) + # --------------------------------------------------------------------------- # # Execution graph build From 308002b9bd9adf7adc57ce99c7315ffe2334494b Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 12:52:45 -0700 Subject: [PATCH 38/42] MAINT: R5.1 ScenarioPipeline polish (persistence + merge order + failure modes) Closes the R5.1 rubber-duck follow-ups on top of R5 (ScenarioPipeline): - Add Scenario._finalize_scenario_result_async base hook (no-op default) called once between the last successful step and the COMPLETED state transition, giving composition subclasses a place to write run-summary state into ScenarioResult.metadata. - Override the hook on ScenarioPipeline to persist per-phase outcomes as metadata['phase_executions'] (a list of name/outcome/inner_scenario_result_id dicts), so cross-process readers can reload the pipeline result and walk phases without holding a live pipeline instance. Class docstring updated to reflect the new persistence contract. - Invert metadata merge order in _build_phase_action: pipeline-stamped diagnostic keys (step_name, phase_index) now win over inner-step result metadata. Regression test pins the inversion against a NoisyStep that emits colliding keys. - Document PipelineContext immutability nuance: structurally frozen at the dataclass level, but inner ScenarioResult payloads are not deep- immutable and should be treated as read-only by convention. - Sharpen input_schema docstring on the kept-but-broken 'phases' role: explicit guidance that the OPAQUE tag is an authoring-refusal signal for the wizard until pipelines can round-trip. - Add TestPipelineFailureModes covering Duck #1's M1 gaps: inner initialize_async / run_async exceptions, predicate exceptions, and partial-progress phase_executions on mid-flight failure. 50/50 composite tests pass (was 41 at R5 ship); 1066/1066 scenario tests pass overall. Pre-commit clean (ruff format/check, ty). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/composite/scenario_pipeline.py | 112 +++++++-- pyrit/scenario/core/scenario.py | 26 ++ .../composite/test_scenario_pipeline.py | 225 ++++++++++++++++++ 3 files changed, 347 insertions(+), 16 deletions(-) diff --git a/pyrit/scenario/composite/scenario_pipeline.py b/pyrit/scenario/composite/scenario_pipeline.py index 1b9827928..067dd539c 100644 --- a/pyrit/scenario/composite/scenario_pipeline.py +++ b/pyrit/scenario/composite/scenario_pipeline.py @@ -100,6 +100,16 @@ class PipelineContext: that already ran. Tuples (not lists) make accidental mutation by the predicate impossible. + .. note:: + + "Read-only" here means *structurally* read-only — the dataclass itself + is frozen and the outer containers are tuples. The :class:`ScenarioResult` + payloads referenced from :attr:`phase_executions` are **not** deep-frozen, + so a predicate that reaches into ``execution.scenario_result.attack_results`` + and mutates the underlying object will affect the persisted result. Treat + inner objects as read-only by convention; don't rely on them being + immutable at the type-system level. + Attributes: completed_phase_names (tuple[str, ...]): Names of every phase that has either completed or been skipped so far, in dispatch order. @@ -423,21 +433,24 @@ class ScenarioPipeline(Scenario): ``ScenarioResult`` (with its own ``scenario_result_id`` persisted independently). - The pipeline's own ``ScenarioResult`` records only the **composition** — - each phase name appears as a key in ``attack_results``, but every bucket - is empty because the inner scenarios persist their own ``AttackResult``s - against their own ``scenario_result_id``s and re-emitting them under the - pipeline's id would duplicate-persist them. - - **Per-phase outcomes are NOT persisted in v1.** They live only on - :attr:`phase_executions`, an in-memory log scoped to the live pipeline - instance. To inspect per-phase outcomes after a pipeline run, you must - keep a reference to the pipeline object and walk - ``pipeline.phase_executions``; reading the persisted outer - ``ScenarioResult`` back via ``get_scenario_results`` from another process - will show only the composition (empty phase buckets), not the outcomes. - See ``tests.unit.scenario.composite.test_scenario_pipeline.TestPipelineExecution`` - for the in-process inspection pattern. + The pipeline's own ``ScenarioResult`` records the **composition** plus + a per-phase outcome summary. Each phase name appears as a key in + ``attack_results`` but every bucket is empty (inner scenarios persist + their own ``AttackResult``s against their own ``scenario_result_id``s + and re-emitting them under the pipeline's id would duplicate-persist + them). The per-phase outcome summary lands in + ``ScenarioResult.metadata["phase_executions"]`` as a list of + ``{"name", "outcome", "inner_scenario_result_id"}`` dicts so + cross-process readers can reload the pipeline result and pull + inner-scenario results by id without holding a reference to the live + pipeline instance. + + Per-phase outcomes are *also* surfaced in-process via + :attr:`phase_executions` for the duration of the live pipeline object; + those snapshots carry the full inner :class:`ScenarioResult` instances + (not just ids) for ergonomic in-process inspection. See + ``tests.unit.scenario.composite.test_scenario_pipeline.TestPipelineExecution`` + for both inspection patterns. Example:: @@ -572,6 +585,17 @@ def input_schema(cls) -> list[RoleDescriptor]: :class:`tests.unit.scenario.composite.test_scenario_pipeline.TestArtifactRoundTripNotSupported` for the regression pin. + The ``phases`` role is **kept** in the schema (despite the broken + artifact path) so the wizard and other introspection consumers can + discover that the pipeline needs a phase list at all. Authoring + consumers (the wizard in particular) should treat the OPAQUE tag as + a hard refusal signal: "this scenario cannot be assembled from + wizard-elicited primitives; the user must supply a programmatic + pipeline instance directly." When the recursive-encoder follow-up + lands and pipelines can round-trip module-level factories, the + schema declaration becomes useful for read-back as well — until + then, it serves as authoring-intent documentation. + Returns: list[RoleDescriptor]: Two roles — ``phases`` (opaque, required) and ``name`` (scalar, optional). @@ -737,10 +761,16 @@ async def _phase_action( graph.bind_current_step(step=step) try: base_result = await step.process_async() + # Pipeline diagnostic keys (``step_name``, ``phase_index``) must + # win over any same-named keys the inner step result may carry, + # so downstream loggers and step_identifier consumers always see + # the pipeline's view of the phase, not whatever the inner step + # decided to stamp. Spread base_result.metadata first, pipeline + # keys last. merged_metadata = { + **base_result.metadata, "step_name": step.name, "phase_index": index, - **base_result.metadata, } result = ScenarioStepResult( outcome=base_result.outcome, @@ -758,6 +788,56 @@ def _record_phase_execution(self, *, execution: PhaseExecution) -> None: """Append a phase execution snapshot to the pipeline's run-scoped log.""" self._phase_executions.append(execution) + async def _finalize_scenario_result_async(self, *, scenario_result_id: str) -> None: + """ + Persist per-phase outcomes into the outer ``ScenarioResult.metadata``. + + The pipeline's :attr:`phase_executions` log is otherwise in-memory + only, scoped to the live ``ScenarioPipeline`` instance. Without this + hook, cross-process readers (anything that reloads the + ``ScenarioResult`` via ``get_scenario_results`` from another process + or after the pipeline object has been garbage-collected) would see + only the composition (empty phase buckets in ``attack_results``) and + no per-phase outcome record. The snapshot is written under the + ``"phase_executions"`` metadata key as a list of + ``{"name": str, "outcome": str, "inner_scenario_result_id": str | None}`` + dicts. Inner scenarios' full ``AttackResult`` rows continue to live + on their own ``scenario_result_id``s; this snapshot only records + which phases ran, what they returned, and how to find the inner + result if there was one. + + Merges with (does not replace) any prior metadata (e.g. the + ``objective_hashes`` that the base ``_build_initial_scenario_metadata`` + may have written at construction time). + + Args: + scenario_result_id (str): The id of the pipeline's + :class:`ScenarioResult`, supplied by the orchestrator. + """ + snapshot: list[dict[str, Any]] = [ + { + "name": execution.name, + "outcome": execution.outcome, + "inner_scenario_result_id": ( + str(execution.scenario_result.id) + if execution.scenario_result is not None and execution.scenario_result.id is not None + else None + ), + } + for execution in self._phase_executions + ] + + current = self._memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + existing_metadata: dict[str, Any] = {} + if current: + existing_metadata = dict(current[0].metadata or {}) + existing_metadata["phase_executions"] = snapshot + + self._memory.update_scenario_metadata( + scenario_result_id=scenario_result_id, + metadata=existing_metadata, + ) + def _snapshot_pipeline_context(self) -> PipelineContext: """ Build the immutable :class:`PipelineContext` view for predicate evaluation. diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index e09c09f2c..9c6d29a2e 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -843,6 +843,27 @@ def _build_initial_scenario_metadata(self) -> dict[str, Any]: metadata["objective_hashes"] = hashes return metadata + async def _finalize_scenario_result_async(self, *, scenario_result_id: str) -> None: + """ + Persist any run-summary state to the scenario result before COMPLETED. + + Called once per successful execution attempt of ``_execute_scenario_async``, + right after the final step completes and before the + ``update_scenario_run_state(COMPLETED)`` transition lands. Subclasses + that need to record run-summary state (e.g. composition pipelines + writing per-phase outcomes into ``ScenarioResult.metadata``) should + override this method. + + The default is a no-op. The ``scenario_result_id`` is supplied so + subclasses don't need to re-derive it from ``self._scenario_result_id``. + + Args: + scenario_result_id (str): The id of the scenario result that is + about to be marked COMPLETED. Use + ``self._memory.update_scenario_metadata`` to write into it. + """ + return + def _apply_persisted_objectives(self, *, stored_result: ScenarioResult) -> None: """ On resume, replay the originally-sampled objective subset. @@ -1579,6 +1600,11 @@ async def _execute_scenario_async(self) -> ScenarioResult: logger.info(f"Scenario '{self._name}' completed successfully") + # Give subclasses a chance to persist run-summary state on the + # ScenarioResult (e.g. composition pipelines writing per-phase + # outcomes into metadata) just before the COMPLETED transition. + await self._finalize_scenario_result_async(scenario_result_id=scenario_result_id) + # Mark scenario as completed self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="COMPLETED" diff --git a/tests/unit/scenario/composite/test_scenario_pipeline.py b/tests/unit/scenario/composite/test_scenario_pipeline.py index 7bc50fcd6..13b38241c 100644 --- a/tests/unit/scenario/composite/test_scenario_pipeline.py +++ b/tests/unit/scenario/composite/test_scenario_pipeline.py @@ -536,6 +536,87 @@ async def test_factory_returning_none_raises_type_error(self, mock_objective_tar await pipeline.run_async() +# --------------------------------------------------------------------------- # +# Failure modes — exception propagation from inner lifecycle and predicates +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestPipelineFailureModes: + """Pin how exceptions raised by inner-scenario lifecycle and predicates surface. + + Covers the gaps Duck #1 (gpt-5.3-codex) flagged in the post-R5 review: + inner ``initialize_async`` failures, inner ``run_async`` failures, and + predicate failures. All three should propagate as exceptions out of + ``pipeline.run_async()`` after marking the scenario FAILED in memory, not + be silently caught. + """ + + async def test_inner_initialize_async_exception_propagates(self, mock_objective_target): + inner = _make_inner_scenario("blowup") + inner.initialize_async.side_effect = RuntimeError("init exploded") + spec = PhaseSpec(name="blowup", scenario_factory=lambda: inner) + + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + with pytest.raises(RuntimeError, match="init exploded"): + await pipeline.run_async() + # run_async should never have been reached. + inner.run_async.assert_not_called() + + async def test_inner_run_async_exception_propagates(self, mock_objective_target): + inner = _make_inner_scenario("midflight") + inner.run_async.side_effect = RuntimeError("run exploded") + spec = PhaseSpec(name="midflight", scenario_factory=lambda: inner) + + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + with pytest.raises(RuntimeError, match="run exploded"): + await pipeline.run_async() + # initialize_async ran, run_async raised. + inner.initialize_async.assert_awaited_once() + + async def test_predicate_exception_propagates(self, mock_objective_target): + inner = _make_inner_scenario("downstream") + + spec_a, inner_a = _phase_spec("a") + spec_b = ConditionalPhaseSpec( + name="b_broken_predicate", + scenario_factory=lambda: inner, + skip_when=lambda _ctx: (_ for _ in ()).throw(ValueError("predicate logic bug")), + ) + + pipeline = ScenarioPipeline(phases=[spec_a, spec_b]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + with pytest.raises(ValueError, match="predicate logic bug"): + await pipeline.run_async() + # Phase a completed before the broken predicate was evaluated. + inner_a.run_async.assert_awaited_once() + # The broken-predicate phase's factory was never called. + inner.run_async.assert_not_called() + + async def test_failure_in_later_phase_records_earlier_phases_executions(self, mock_objective_target): + """When phase N fails, phase_executions for phases 1..N-1 should still be recorded.""" + spec_a, inner_a = _phase_spec("alpha") + inner_b = _make_inner_scenario("beta") + inner_b.run_async.side_effect = RuntimeError("phase 2 down") + spec_b = PhaseSpec(name="beta", scenario_factory=lambda: inner_b) + + pipeline = ScenarioPipeline(phases=[spec_a, spec_b]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + with pytest.raises(RuntimeError, match="phase 2 down"): + await pipeline.run_async() + + # The in-memory log captures the completed phase before the crash. + assert len(pipeline.phase_executions) == 1 + assert pipeline.phase_executions[0].name == "alpha" + assert pipeline.phase_executions[0].outcome == _PHASE_OUTCOME_COMPLETED + + # --------------------------------------------------------------------------- # # Artifact round-trip limitation (v1 contract pin) # --------------------------------------------------------------------------- # @@ -602,3 +683,147 @@ async def test_callable_factory_breaks_yaml_dump(self, mock_objective_target, tm out_path = tmp_path / "pipeline.yaml" with pytest.raises(yaml.representer.RepresenterError, match="cannot represent"): graph_artifact_to_yaml(artifact, out_path) + + +# --------------------------------------------------------------------------- # +# Metadata merge order (pipeline keys always win) +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestPhaseActionMetadataMergeOrder: + """Pin that pipeline-stamped metadata wins over inner step result metadata. + + The orchestrator's logging and ``step_identifier`` consumers downstream of + ``_build_phase_action`` always need the pipeline's view of + ``step_name`` / ``phase_index``, not whatever the inner step decided to + stamp. A misbehaving inner that emits ``"step_name"`` in its result + metadata must not override the pipeline's stamp. + """ + + async def test_pipeline_keys_override_inner_step_metadata(self, mock_objective_target): + from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult + + class _NoisyStep(ScenarioStep): + """A step whose process_async result carries colliding metadata keys.""" + + def __init__(self, name: str) -> None: + self.name = name + self.outputs = ["done"] + + async def process_async(self) -> ScenarioStepResult: + return ScenarioStepResult( + outcome="done", + attack_results=[], + metadata={ + "step_name": "INNER_HIJACK", + "phase_index": 999, + "innocuous": "inner_value", + }, + ) + + def _build_identifier(self) -> ComponentIdentifier: + return ComponentIdentifier.of(self, params={"name": self.name}) + + spec, _ = _phase_spec("pipeline_phase_zero") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + noisy = _NoisyStep(name="pipeline_phase_zero") + action = pipeline._build_phase_action(index=0, step=noisy) + graph = pipeline._build_execution_graph(steps=[noisy]) + + next_state, result = await action(graph) + + assert next_state == 1 + assert result is not None + # Pipeline keys win: + assert result.metadata["step_name"] == "pipeline_phase_zero" + assert result.metadata["phase_index"] == 0 + # Innocuous inner keys still pass through: + assert result.metadata["innocuous"] == "inner_value" + + +# --------------------------------------------------------------------------- # +# Persisted ScenarioResult.metadata["phase_executions"] +# --------------------------------------------------------------------------- # + + +@pytest.mark.usefixtures("patch_central_database") +class TestPhaseExecutionsPersistence: + """Pin that per-phase outcomes are written into the outer ScenarioResult.metadata. + + Without this persistence, cross-process readers of the pipeline's + ``ScenarioResult`` (anything reloading via ``get_scenario_results`` + without holding a live pipeline reference) would see only the composition + (empty phase buckets) and no per-phase outcome record. The persisted + snapshot is a list of ``{"name", "outcome", "inner_scenario_result_id"}`` + dicts keyed under ``metadata["phase_executions"]``. + """ + + async def test_phase_executions_persisted_in_scenario_result_metadata(self, mock_objective_target): + from pyrit.memory import CentralMemory + + spec_a, inner_a = _phase_spec("alpha") + spec_b, inner_b = _phase_spec("beta") + + pipeline = ScenarioPipeline(phases=[spec_a, spec_b]) + await pipeline.initialize_async(objective_target=mock_objective_target) + result = await pipeline.run_async() + + memory = CentralMemory.get_memory_instance() + [persisted] = memory.get_scenario_results(scenario_result_ids=[str(result.id)]) + snapshot = persisted.metadata["phase_executions"] + + assert isinstance(snapshot, list) + assert len(snapshot) == 2 + assert snapshot[0]["name"] == "alpha" + assert snapshot[0]["outcome"] == _PHASE_OUTCOME_COMPLETED + assert snapshot[0]["inner_scenario_result_id"] == str(inner_a.run_async.return_value.id) + assert snapshot[1]["name"] == "beta" + assert snapshot[1]["outcome"] == _PHASE_OUTCOME_COMPLETED + assert snapshot[1]["inner_scenario_result_id"] == str(inner_b.run_async.return_value.id) + + async def test_skipped_phase_persists_with_null_inner_result_id(self, mock_objective_target): + from pyrit.memory import CentralMemory + + spec_a, _ = _phase_spec("a") + spec_b, inner_b = _conditional_phase_spec("b", skip_when=lambda _ctx: True) + + pipeline = ScenarioPipeline(phases=[spec_a, spec_b]) + await pipeline.initialize_async(objective_target=mock_objective_target) + result = await pipeline.run_async() + + memory = CentralMemory.get_memory_instance() + [persisted] = memory.get_scenario_results(scenario_result_ids=[str(result.id)]) + snapshot = persisted.metadata["phase_executions"] + + assert snapshot[1]["name"] == "b" + assert snapshot[1]["outcome"] == _PHASE_OUTCOME_SKIPPED + assert snapshot[1]["inner_scenario_result_id"] is None + # Predicate elided execution, factory should never have been called. + inner_b.run_async.assert_not_called() + + async def test_finalize_hook_preserves_existing_metadata_keys(self, mock_objective_target): + """The finalize override merges with prior metadata, doesn't replace it.""" + from pyrit.memory import CentralMemory + + spec, _ = _phase_spec("only") + pipeline = ScenarioPipeline(phases=[spec]) + await pipeline.initialize_async(objective_target=mock_objective_target) + + # Simulate a prior metadata entry (e.g. what the base _build_initial_scenario_metadata + # would write when max_dataset_size is set). + memory = CentralMemory.get_memory_instance() + memory.update_scenario_metadata( + scenario_result_id=str(pipeline._scenario_result_id), + metadata={"objective_hashes": ["sentinel-pre-existing-hash"]}, + ) + + result = await pipeline.run_async() + [persisted] = memory.get_scenario_results(scenario_result_ids=[str(result.id)]) + + # Both keys must coexist; the finalize hook merges, doesn't replace. + assert persisted.metadata["objective_hashes"] == ["sentinel-pre-existing-hash"] + assert "phase_executions" in persisted.metadata + assert persisted.metadata["phase_executions"][0]["name"] == "only" From 06e8257744a03a467dd19b802ee3800afe87b3af Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 13:15:20 -0700 Subject: [PATCH 39/42] MAINT: R1 collapse adaptive override into base linear policy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Unify scenario step dispatch on `ScenarioStep.process_async` so the base linear policy handles AtomicAttack, AdaptiveStep, and any future ScenarioStep subclass through one code path. Adds a setter on `AtomicAttack` so the base policy can push the scenario-level `max_concurrency` into atomic steps without the orchestrator special-casing step types. Introduces `LinearScenario` as the L0 authoring tier so users can construct a scenario from a list of pre-built steps without subclassing. - `AtomicAttack.set_scenario_max_concurrency` + `_scenario_max_concurrency` instance state, with `process_async` honoring the bound value when delegating to `run_async`. - `Scenario._build_default_linear_policy` now pushes max_concurrency into every `AtomicAttack` step before the action loop and always dispatches via `process_async` (removes the isinstance branch that forced AdaptiveStep authors into L2). - `AdaptiveScenario._build_execution_graph` and `_build_adaptive_linear_policy` (~80 LOC) deleted; the base linear policy now drives adaptive correctly because outcomes propagate verbatim from `AdaptiveStep.process_async`. - `LinearScenario(steps=[...], objective_scorer=...)` returns a runnable scenario with zero subclassing — the L0 entry point sketched in the R1 plan response to rlundeen's PR #1767 review. Test fixture pattern: `MagicMock(spec=AtomicAttack)` AsyncMock fallback for `process_async` returns coroutines that fail metadata unpacking. Five test fixtures updated to wire `process_async` to delegate to `run_async` so existing `run_async.assert_called_with` assertions continue to work through the new dispatch chain. New tests cover the setter validation, process_async max_concurrency forwarding, and end-to-end LinearScenario execution. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/__init__.py | 4 + pyrit/scenario/core/__init__.py | 3 + pyrit/scenario/core/atomic_attack.py | 42 +++- pyrit/scenario/core/linear_scenario.py | 161 +++++++++++++ pyrit/scenario/core/scenario.py | 65 +++--- .../scenarios/adaptive/adaptive_scenario.py | 94 +------- .../scenario/core/test_linear_scenario.py | 215 ++++++++++++++++++ .../scenarios/adaptive/test_text_adaptive.py | 4 +- .../test_atomic_attack_scenario_step.py | 87 +++++++ tests/unit/scenario/test_scenario.py | 49 +++- .../scenario/test_scenario_graph_execution.py | 34 ++- .../scenario/test_scenario_partial_results.py | 24 ++ tests/unit/scenario/test_scenario_retry.py | 24 ++ 13 files changed, 666 insertions(+), 140 deletions(-) create mode 100644 pyrit/scenario/core/linear_scenario.py create mode 100644 tests/unit/scenario/core/test_linear_scenario.py diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 9a3ea03e8..0421df9cd 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -40,6 +40,8 @@ GraphArtifactError, GraphArtifactSecurityError, InputCollector, + LinearScenario, + LinearScenarioStrategy, MaxAttemptsExceededError, OpaqueInputUnresolvedError, OpaqueRoleNotElicitableError, @@ -113,6 +115,8 @@ "GraphArtifactError", "GraphArtifactSecurityError", "InputCollector", + "LinearScenario", + "LinearScenarioStrategy", "MaxAttemptsExceededError", "OpaqueInputUnresolvedError", "OpaqueRoleNotElicitableError", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index eebea0494..7f6155dd1 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -39,6 +39,7 @@ collect_inputs_with_retry, ) from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag +from pyrit.scenario.core.linear_scenario import LinearScenario, LinearScenarioStrategy from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_state import ScenarioCoreState, ScenarioStateLike from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult @@ -66,6 +67,8 @@ "GraphArtifactError", "GraphArtifactSecurityError", "InputCollector", + "LinearScenario", + "LinearScenarioStrategy", "MaxAttemptsExceededError", "OpaqueInputUnresolvedError", "OpaqueRoleNotElicitableError", diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index 8015c5ca3..238506142 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -141,6 +141,12 @@ def __init__( # the scenario via the attribution_parent_id foreign key on # AttackResultEntry. self._scenario_result_id: str | None = None + # Set via set_scenario_max_concurrency() by Scenario._build_default_linear_policy + # before the policy's per-step action runs. Consumed by ``process_async`` + # so a scenario-level ``max_concurrency`` setting flows through the + # unified linear-dispatch path without the orchestrator needing to + # branch on step type. + self._scenario_max_concurrency: int = 1 logger.info( f"Initialized atomic attack with {len(self._seed_groups)} seed groups, " @@ -163,6 +169,25 @@ def set_scenario_result_id(self, scenario_result_id: str | None) -> None: """ self._scenario_result_id = scenario_result_id + def set_scenario_max_concurrency(self, max_concurrency: int) -> None: + """ + Bind the scenario-level ``max_concurrency`` for ``process_async`` dispatch. + + Called by ``Scenario._build_default_linear_policy`` once at policy build + time so the unified ``process_async`` action body can honor scenario + concurrency without the orchestrator special-casing ``AtomicAttack``. + Direct ``run_async`` callers (outside of a scenario) are unaffected. + + Args: + max_concurrency (int): Positive integer concurrency limit. + + Raises: + ValueError: If ``max_concurrency`` is less than ``1``. + """ + if max_concurrency < 1: + raise ValueError("max_concurrency must be >= 1") + self._scenario_max_concurrency = max_concurrency + def _validate_unique_objective_hashes(self) -> None: """ Ensure each seed group in this atomic attack has a unique objective hash. @@ -258,17 +283,22 @@ async def process_async(self) -> ScenarioStepResult: """ ``ScenarioStep`` adapter — runs the atomic attack and wraps the result. - Delegates to ``run_async`` using the instance's stored execution - parameters, then packages the completed results into a - ``ScenarioStepResult``. Incomplete objectives and the executor's - ``input_indices`` are stashed in ``metadata`` so the orchestrator - (Phase 5) can drive resume / retry logic without losing information. + Delegates to ``run_async`` honoring the scenario-level + ``max_concurrency`` bound via :meth:`set_scenario_max_concurrency` + (defaults to ``1`` when invoked outside a scenario). Completed results + are packaged into a ``ScenarioStepResult``; incomplete objectives and + the executor's ``input_indices`` are stashed in ``metadata`` so the + orchestrator can drive resume / retry logic without losing + information. Returns: ScenarioStepResult: ``outcome="done"`` with the completed attack results and execution bookkeeping in ``metadata``. """ - executor_result = await self.run_async() + executor_result = await self.run_async( + max_concurrency=self._scenario_max_concurrency, + return_partial_on_failure=True, + ) return ScenarioStepResult( outcome="done", attack_results=list(executor_result.completed_results), diff --git a/pyrit/scenario/core/linear_scenario.py b/pyrit/scenario/core/linear_scenario.py new file mode 100644 index 000000000..05facb452 --- /dev/null +++ b/pyrit/scenario/core/linear_scenario.py @@ -0,0 +1,161 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +``LinearScenario`` — L0 authoring tier for "I just want to run these steps in order". + +This module exists so scenario authors can hand a list of pre-built +``ScenarioStep`` instances (typically ``AtomicAttack`` objects) to a +zero-subclass scenario and get a runnable :class:`Scenario` back. Composes +purely on top of the default linear policy built by +:meth:`Scenario._build_default_linear_policy`; no graph or state vocabulary +is exposed to the caller. + +L1 (override ``_get_atomic_attacks_async``) and L2 (override +``_build_execution_graph``) remain available for scenarios that need +registry-driven technique selection or non-linear control flow respectively. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, Optional, cast + +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario +from pyrit.scenario.core.scenario_strategy import ScenarioStrategy + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pyrit.scenario.core.atomic_attack import AtomicAttack + from pyrit.scenario.core.scenario_step import ScenarioStep + from pyrit.score import Scorer + + +class LinearScenarioStrategy(ScenarioStrategy): + """ + Single-member strategy enum for :class:`LinearScenario`. + + ``LinearScenario`` doesn't run a technique-selection menu — the steps + are constructor inputs. The strategy enum exists only to satisfy the + base ``Scenario`` contract. + """ + + DEFAULT = ("default", {"all"}) + + @classmethod + def get_aggregate_tags(cls) -> set[str]: + """Return the strategy aggregate tags this enum exposes.""" + return {"all"} + + +class LinearScenario(Scenario): + """ + Run a caller-supplied list of ``ScenarioStep`` instances in order. + + The L0 authoring tier for scenarios that don't need a custom subclass, + a technique registry, or non-linear control flow. The supplied steps + are walked by the default linear policy (:meth:`Scenario._build_default_linear_policy`), + so ``AtomicAttack`` steps automatically receive scenario-level + ``max_concurrency`` via :meth:`AtomicAttack.set_scenario_max_concurrency` + and all dispatch flows uniformly through ``step.process_async``. + + Example: + ```python + scenario = LinearScenario( + steps=[atomic_a, atomic_b], + objective_scorer=my_scorer, + ) + await scenario.initialize_async(objective_target=target) + result = await scenario.run_async() + ``` + """ + + #: Default scenario version; bump if behavior changes in a way that + #: invalidates resume from an older persisted ``ScenarioResult``. + VERSION: ClassVar[int] = 1 + + #: The steps are supplied directly; baseline injection would mutate + #: that list and surprise the caller. + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Forbidden + + def __init__( + self, + *, + steps: Sequence[ScenarioStep], + objective_scorer: Scorer, + name: str = "", + version: int | None = None, + scenario_result_id: Optional[str] = None, + ) -> None: + """ + Initialize a linear scenario from a caller-supplied step list. + + Args: + steps (Sequence[ScenarioStep]): Steps to walk in order. Must be + non-empty. + objective_scorer (Scorer): Forwarded to the base ``Scenario``. + Used by baseline / retry paths; for L0 usage with explicit + steps and ``BASELINE_ATTACK_POLICY = Forbidden``, it is + effectively only consulted if a custom retry path is wired in. + name (str): Descriptive name for the scenario. + version (int | None): Scenario version. Defaults to ``VERSION``; + override only when callers want to invalidate resume from a + prior shape. + scenario_result_id (str | None): Optional ID of an existing + scenario result to resume. + + Raises: + ValueError: If ``steps`` is empty. + """ + if not steps: + raise ValueError("LinearScenario requires at least one step.") + + self._explicit_steps: list[ScenarioStep] = list(steps) + + super().__init__( + name=name, + version=version if version is not None else self.VERSION, + strategy_class=self.get_strategy_class(), + objective_scorer=objective_scorer, + scenario_result_id=scenario_result_id, + ) + + @classmethod + def get_strategy_class(cls) -> type[ScenarioStrategy]: + """Return the single-member strategy enum class.""" + return LinearScenarioStrategy + + @classmethod + def get_default_strategy(cls) -> ScenarioStrategy: + """Return the only strategy member.""" + return LinearScenarioStrategy.DEFAULT + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + """ + Return an empty dataset configuration. + + Steps are supplied directly via the constructor, so the base + scenario's auto-build from registered datasets is unused. + + Returns: + DatasetConfiguration: An empty configuration. + """ + return DatasetConfiguration() + + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + """ + Return the caller-supplied steps in order. + + The return type is :class:`list[AtomicAttack]` for parity with the + base ``Scenario._get_atomic_attacks_async`` contract; the resume / + orchestrator code reads the duck-typed attributes that every + ``ScenarioStep`` exposes (``name``, ``process_async``) so non- + ``AtomicAttack`` subclasses pass through cleanly. + + Returns: + list[AtomicAttack]: The caller-supplied steps, cast to satisfy the + base contract's type signature. + """ + return cast("list[AtomicAttack]", list(self._explicit_steps)) diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 9c6d29a2e..76a90a8cf 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -1231,15 +1231,17 @@ def _build_default_linear_policy(self, *, steps: Sequence[ScenarioStep]) -> Stra """ Build a linear-traversal policy that preserves scenario-level execution params. - Each policy action runs ``steps[i]`` and transitions to state ``i + 1``; - state ``len(steps)`` is the sole terminal state. For ``AtomicAttack`` - steps the action calls ``run_async`` directly so ``max_concurrency`` and - ``return_partial_on_failure`` semantics that the legacy flat loop relied - on are preserved. Non-``AtomicAttack`` steps fall back to - ``process_async`` (so any future custom ``ScenarioStep`` subclass works - out of the box). In both paths the step's ``name`` is stamped into + Each policy action runs ``steps[i].process_async()`` and transitions + to state ``i + 1``; state ``len(steps)`` is the sole terminal state. + Every step type — ``AtomicAttack``, ``AdaptiveStep``, or any future + custom ``ScenarioStep`` subclass — flows through the same uniform + dispatch path. ``AtomicAttack`` steps receive scenario-level + ``max_concurrency`` via :meth:`AtomicAttack.set_scenario_max_concurrency` + before the policy is frozen, so the unified action body does not need + to branch on step type. The step's ``name`` is stamped into ``ScenarioStepResult.metadata['step_name']`` so the orchestrator can - identify the step at yield time. + identify the step at yield time (caller-supplied metadata wins on + collision). Args: steps (Sequence[ScenarioStep]): The steps to wrap. Must be non-empty. @@ -1253,7 +1255,14 @@ def _build_default_linear_policy(self, *, steps: Sequence[ScenarioStep]) -> Stra if not steps: raise ValueError("_build_default_linear_policy requires at least one step.") - max_concurrency = self._max_concurrency + # Push the scenario-level max_concurrency into every AtomicAttack step + # exactly once, before any action runs. Non-AtomicAttack steps either + # own their own concurrency (e.g. AdaptiveStep) or default to 1, so the + # orchestrator stays out of their dispatch. + for step in steps: + if isinstance(step, AtomicAttack): + step.set_scenario_max_concurrency(self._max_concurrency) + terminal_state = len(steps) actions: dict[int, PolicyAction[ScenarioStep, int]] = {} @@ -1263,35 +1272,21 @@ async def _action( graph: StrategyGraph[ScenarioStep, int], _step: ScenarioStep = step, _next: int = index + 1, - _max_concurrency: int = max_concurrency, ) -> tuple[int, ScenarioStepResult | None]: graph.bind_current_step(step=_step) try: - if isinstance(_step, AtomicAttack): - executor_result = await _step.run_async( - max_concurrency=_max_concurrency, - return_partial_on_failure=True, - ) - result: ScenarioStepResult | None = ScenarioStepResult( - outcome="done", - attack_results=list(executor_result.completed_results), - metadata={ - "step_name": _step.atomic_attack_name, - "incomplete_objectives": list(executor_result.incomplete_objectives), - "input_indices": list(executor_result.input_indices), - }, - ) - else: - base_result = await _step.process_async() - # Re-stamp metadata with step_name so the orchestrator can route results - # without depending on graph.current_step (which is cleared before yield). - merged_metadata = {"step_name": _step.name, **base_result.metadata} - result = ScenarioStepResult( - outcome=base_result.outcome, - attack_results=base_result.attack_results, - step_identifier=base_result.step_identifier, - metadata=merged_metadata, - ) + base_result = await _step.process_async() + # Stamp ``step_name`` so the orchestrator can route the + # result without depending on ``graph.current_step`` + # (cleared before yield). Caller metadata wins on + # collision so steps remain authoritative. + merged_metadata = {"step_name": _step.name, **base_result.metadata} + result: ScenarioStepResult | None = ScenarioStepResult( + outcome=base_result.outcome, + attack_results=list(base_result.attack_results), + step_identifier=base_result.step_identifier, + metadata=merged_metadata, + ) finally: graph.bind_current_step(step=None) return _next, result diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index 19a604365..9e615eb34 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -24,8 +24,6 @@ from pyrit.executor.attack import AttackScoringConfig from pyrit.scenario.core.input_schema import RoleDescriptor, RoleTag from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario -from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult -from pyrit.scenario.core.strategy_graph import PolicyAction, StrategyGraph, StrategyPolicy from pyrit.scenario.scenarios.adaptive.adaptive_step import AdaptiveStep from pyrit.scenario.scenarios.adaptive.dispatcher import ( ADAPTIVE_CONTEXT_LABEL, @@ -38,8 +36,6 @@ ) if TYPE_CHECKING: - from collections.abc import Sequence - from pyrit.models import SeedAttackGroup from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.atomic_attack import AtomicAttack @@ -201,9 +197,10 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: resume bookkeeping treats steps via the duck-typed attributes :class:`AdaptiveStep` provides (``atomic_attack_name``, ``objectives``, ``seed_groups``, ``display_group``, ``filter_seed_groups_by_objectives``). - Execution dispatch in :meth:`_build_execution_graph` calls - ``step.process_async`` directly, bypassing the default linear policy's - ``AtomicAttack.run_async`` branch. + Execution flows through the unified default linear policy + (:meth:`Scenario._build_default_linear_policy`), which dispatches every + step — atomic or adaptive — via ``step.process_async`` so the + ``"success"`` / ``"exhausted"`` outcome labels propagate unchanged. Returns: list[AtomicAttack]: One step per objective with at least one @@ -363,89 +360,6 @@ def _build_step_for_seed_group( adaptive_context=adaptive_context, ) - def _build_execution_graph( - self, - *, - steps: Sequence[ScenarioStep] | None = None, - ) -> StrategyGraph[ScenarioStep, int]: - """ - Build a linear graph that drives each :class:`AdaptiveStep` via - ``process_async`` so the ``"success"`` / ``"exhausted"`` outcome - labels survive into the orchestrator (the default policy from the - base class would dispatch ``AtomicAttack`` instances through - ``run_async`` and lose the outcome distinction). - - Args: - steps: Optional explicit step list. Defaults to - ``self._atomic_attacks``, mirroring the base class contract. - - Returns: - StrategyGraph[ScenarioStep, int]: A linear traversal whose actions - always dispatch via ``step.process_async``. - """ - effective_steps = list(steps) if steps is not None else list(self._atomic_attacks) - policy = self._build_adaptive_linear_policy(steps=effective_steps) - return StrategyGraph(policy=policy) - - def _build_adaptive_linear_policy( - self, - *, - steps: Sequence[ScenarioStep], - ) -> StrategyPolicy[ScenarioStep, int]: - """ - Build a linear policy that always dispatches via ``process_async``. - - Each policy action runs ``steps[i].process_async()`` and transitions - to state ``i + 1``; state ``len(steps)`` is the sole terminal state. - Unlike :meth:`Scenario._build_default_linear_policy` there's no - ``isinstance(_step, AtomicAttack)`` branch — adaptive steps always - go through their own process_async loop so the - ``"success"``/``"exhausted"`` outcome labels propagate unchanged. - - Args: - steps: The steps to wrap. Must be non-empty. - - Returns: - StrategyPolicy[ScenarioStep, int]: A frozen linear policy. - - Raises: - ValueError: If ``steps`` is empty. - """ - if not steps: - raise ValueError("_build_adaptive_linear_policy requires at least one step.") - - terminal_state = len(steps) - actions: dict[int, PolicyAction[ScenarioStep, int]] = {} - - for index, step in enumerate(steps): - - async def _action( - graph: StrategyGraph[ScenarioStep, int], - _step: ScenarioStep = step, - _next: int = index + 1, - ) -> tuple[int, ScenarioStepResult | None]: - graph.bind_current_step(step=_step) - try: - base_result = await _step.process_async() - merged_metadata = {"step_name": _step.name, **base_result.metadata} - result: ScenarioStepResult | None = ScenarioStepResult( - outcome=base_result.outcome, - attack_results=list(base_result.attack_results), - step_identifier=base_result.step_identifier, - metadata=merged_metadata, - ) - finally: - graph.bind_current_step(step=None) - return _next, result - - actions[index] = _action - - return StrategyPolicy( - actions=actions, - initial_state=0, - terminal_states=frozenset({terminal_state}), - ) - def _rehydrate_selector_from_memory( self, *, diff --git a/tests/unit/scenario/core/test_linear_scenario.py b/tests/unit/scenario/core/test_linear_scenario.py new file mode 100644 index 000000000..1c6c7f417 --- /dev/null +++ b/tests/unit/scenario/core/test_linear_scenario.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for :class:`LinearScenario` — the L0 authoring tier for scenarios. + +Pins the contract that LinearScenario gives users a runnable scenario from +a list of pre-built ScenarioStep instances without subclassing, without a +graph vocabulary, and without registry-driven technique selection. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, PropertyMock + +import pytest + +from pyrit.executor.attack.core import AttackExecutorResult +from pyrit.identifiers import ComponentIdentifier +from pyrit.memory import CentralMemory +from pyrit.models import AttackOutcome, AttackResult +from pyrit.scenario import AtomicAttack, DatasetConfiguration, ScenarioResult +from pyrit.scenario.core import ( + BaselineAttackPolicy, + LinearScenario, + LinearScenarioStrategy, + Scenario, + ScenarioStrategy, +) +from pyrit.scenario.core.scenario_step import ScenarioStepResult +from pyrit.score import Scorer + +_TEST_SCORER_ID = ComponentIdentifier( + class_name="MockScorer", + class_module="tests.unit.scenarios", +) + + +def _make_scorer() -> MagicMock: + scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = _TEST_SCORER_ID + scorer.get_scorer_metrics.return_value = None + return scorer + + +def _make_atomic_attack_mock(name: str, attack_result: AttackResult) -> MagicMock: + """Build a fake AtomicAttack whose run_async returns the supplied result. + + Mirrors the canonical fixture pattern in test_scenario_graph_execution.py. + """ + mock_attack = MagicMock() + mock_attack.get_objective_target.return_value = MagicMock() + mock_attack.get_attack_scoring_config.return_value = MagicMock() + + attack = MagicMock(spec=AtomicAttack) + attack.atomic_attack_name = name + attack.name = name + attack.display_group = name + attack._attack = mock_attack + attack._scenario_result_id = None + attack._scenario_max_concurrency = 1 + type(attack).objectives = PropertyMock(return_value=[attack_result.objective]) + + def _set_scenario_result_id(sid): + attack._scenario_result_id = sid + + attack.set_scenario_result_id = MagicMock(side_effect=_set_scenario_result_id) + + def _set_scenario_max_concurrency(mc): + attack._scenario_max_concurrency = mc + + attack.set_scenario_max_concurrency = MagicMock(side_effect=_set_scenario_max_concurrency) + + async def _fake_run(*args, **kwargs): + memory = CentralMemory.get_memory_instance() + sid = attack._scenario_result_id + if sid: + attack_result.attribution_parent_id = sid + attack_result.attribution_data = {"parent_collection": name} + memory.add_attack_results_to_memory(attack_results=[attack_result]) + return AttackExecutorResult(completed_results=[attack_result], incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_fake_run) + + async def _fake_process(*args, **kwargs): + executor_result = await attack.run_async( + max_concurrency=attack._scenario_max_concurrency, + return_partial_on_failure=True, + ) + return ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + + attack.process_async = MagicMock(side_effect=_fake_process) + return attack + + +def _sample_result(index: int) -> AttackResult: + result = AttackResult( + conversation_id=f"conv-{index}", + objective=f"objective-{index}", + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + result.atomic_attack_identifier = ComponentIdentifier( + class_name="MockAttack", + class_module="tests.unit.scenarios", + params={"name": f"attack-{index}"}, + ) + return result + + +@pytest.fixture +def mock_objective_target(): + target = MagicMock() + target.get_identifier.return_value = ComponentIdentifier( + class_name="MockTarget", + class_module="test", + ) + return target + + +class TestLinearScenarioConstruction: + """LinearScenario rejects empty step lists and exposes its strategy class.""" + + def test_rejects_empty_steps(self): + with pytest.raises(ValueError, match="at least one step"): + LinearScenario(steps=[], objective_scorer=_make_scorer()) + + def test_is_scenario_subclass(self): + assert issubclass(LinearScenario, Scenario) + + def test_baseline_policy_is_forbidden(self): + assert BaselineAttackPolicy.Forbidden == LinearScenario.BASELINE_ATTACK_POLICY + + def test_strategy_class_is_single_member(self): + cls = LinearScenario.get_strategy_class() + assert cls is LinearScenarioStrategy + members = list(cls) + assert len(members) == 1 + assert members[0] is LinearScenarioStrategy.DEFAULT + + def test_default_strategy_is_default_member(self): + assert LinearScenario.get_default_strategy() is LinearScenarioStrategy.DEFAULT + + def test_default_dataset_config_is_empty(self): + config = LinearScenario.default_dataset_config() + assert isinstance(config, DatasetConfiguration) + + def test_strategy_aggregate_tag_is_all(self): + assert LinearScenarioStrategy.get_aggregate_tags() == {"all"} + + def test_strategy_is_scenario_strategy_subclass(self): + assert issubclass(LinearScenarioStrategy, ScenarioStrategy) + + +@pytest.mark.usefixtures("patch_central_database") +class TestLinearScenarioExecution: + """End-to-end pin that LinearScenario walks its step list in order.""" + + async def test_runs_steps_in_order(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(3)] + scenario = LinearScenario(steps=attacks, objective_scorer=_make_scorer()) + await scenario.initialize_async(objective_target=mock_objective_target) + + result = await scenario.run_async() + + assert isinstance(result, ScenarioResult) + names = [r.metadata.get("step_name") for r in scenario.execution_history] + assert names == ["a0", "a1", "a2"] + + async def test_each_step_runs_exactly_once(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(3)] + scenario = LinearScenario(steps=attacks, objective_scorer=_make_scorer()) + await scenario.initialize_async(objective_target=mock_objective_target) + + await scenario.run_async() + + for attack in attacks: + attack.run_async.assert_called_once() + + async def test_max_concurrency_propagates_through_setter(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] + scenario = LinearScenario(steps=attacks, objective_scorer=_make_scorer()) + await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=5) + + await scenario.run_async() + + for attack in attacks: + attack.set_scenario_max_concurrency.assert_called_with(5) + attack.run_async.assert_called_once_with(max_concurrency=5, return_partial_on_failure=True) + + async def test_explicit_steps_preserved_across_initialize(self, mock_objective_target): + # ``initialize_async`` calls ``_get_atomic_attacks_async`` under the hood + # to populate ``self._atomic_attacks``. LinearScenario must return the + # exact steps supplied at construction time. + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(2)] + scenario = LinearScenario(steps=attacks, objective_scorer=_make_scorer()) + await scenario.initialize_async(objective_target=mock_objective_target) + + assert list(scenario._atomic_attacks) == attacks + + async def test_history_length_matches_step_count(self, mock_objective_target): + attacks = [_make_atomic_attack_mock(f"a{i}", _sample_result(i)) for i in range(4)] + scenario = LinearScenario(steps=attacks, objective_scorer=_make_scorer()) + await scenario.initialize_async(objective_target=mock_objective_target) + + await scenario.run_async() + + assert len(scenario.execution_history) == 4 diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index 6eab442df..2bd00571a 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -594,12 +594,12 @@ class TestAdaptiveLinearPolicy: def test_empty_steps_raises(self, mock_objective_scorer): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) with pytest.raises(ValueError, match="at least one step"): - scenario._build_adaptive_linear_policy(steps=[]) + scenario._build_default_linear_policy(steps=[]) def test_initial_state_zero_and_terminal_state_is_step_count(self, mock_objective_scorer): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) steps = [_make_stub_step(name=f"s{i}") for i in range(3)] - policy = scenario._build_adaptive_linear_policy(steps=steps) + policy = scenario._build_default_linear_policy(steps=steps) assert isinstance(policy, StrategyPolicy) assert policy.initial_state == 0 diff --git a/tests/unit/scenario/test_atomic_attack_scenario_step.py b/tests/unit/scenario/test_atomic_attack_scenario_step.py index 7accffc80..a29a83dfd 100644 --- a/tests/unit/scenario/test_atomic_attack_scenario_step.py +++ b/tests/unit/scenario/test_atomic_attack_scenario_step.py @@ -262,3 +262,90 @@ def test_filter_ignores_unknown_objectives(self, mock_attack, seed_groups): ) atomic.filter_seed_groups_by_objectives(remaining_objectives=["obj1", "does_not_exist"]) assert atomic.objectives == ["obj1"] + + +@pytest.mark.usefixtures("patch_central_database") +class TestAtomicAttackSetScenarioMaxConcurrency: + """``set_scenario_max_concurrency`` binds the value consumed by ``process_async``. + + R1 unified the linear-policy dispatch so the base policy pushes + scenario-level ``max_concurrency`` into every ``AtomicAttack`` step via this + setter before the action loop runs. The setter must validate its input and + the value must flow through to ``run_async`` via ``process_async``. + """ + + def test_default_is_one(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + assert atomic._scenario_max_concurrency == 1 + + def test_setter_accepts_positive(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + atomic.set_scenario_max_concurrency(5) + assert atomic._scenario_max_concurrency == 5 + + def test_setter_rejects_zero(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + with pytest.raises(ValueError, match="max_concurrency must be >= 1"): + atomic.set_scenario_max_concurrency(0) + + def test_setter_rejects_negative(self, mock_attack, seed_groups): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + with pytest.raises(ValueError, match="max_concurrency must be >= 1"): + atomic.set_scenario_max_concurrency(-3) + + async def test_process_async_uses_bound_max_concurrency(self, mock_attack, seed_groups, attack_results): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + atomic.set_scenario_max_concurrency(11) + + with patch.object(AttackExecutor, "execute_attack_from_seed_groups_async", new_callable=AsyncMock) as mock_exec: + mock_exec.return_value = AttackExecutorResult( + completed_results=attack_results, + incomplete_objectives=[], + input_indices=[0, 1], + ) + with patch.object(AtomicAttack, "run_async", new_callable=AsyncMock) as mock_run: + mock_run.return_value = AttackExecutorResult( + completed_results=attack_results, + incomplete_objectives=[], + input_indices=[0, 1], + ) + await atomic.process_async() + + mock_run.assert_called_once_with(max_concurrency=11, return_partial_on_failure=True) + + async def test_process_async_defaults_to_one_when_unset(self, mock_attack, seed_groups, attack_results): + atomic = AtomicAttack( + attack_technique=AttackTechnique(attack=mock_attack), + seed_groups=seed_groups, + atomic_attack_name="my_step", + ) + + with patch.object(AtomicAttack, "run_async", new_callable=AsyncMock) as mock_run: + mock_run.return_value = AttackExecutorResult( + completed_results=attack_results, + incomplete_objectives=[], + input_indices=[0, 1], + ) + await atomic.process_async() + + mock_run.assert_called_once_with(max_concurrency=1, return_partial_on_failure=True) diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index a4971a3e0..2d98425a4 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -14,6 +14,7 @@ from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioIdentifier, ScenarioResult from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy, Scenario, ScenarioStrategy +from pyrit.scenario.core.scenario_step import ScenarioStepResult from pyrit.score import Scorer # Reusable test scorer identifier @@ -23,6 +24,45 @@ ) +def _wire_process_async(attack: MagicMock) -> None: + """ + Wire an :class:`AtomicAttack` MagicMock's ``process_async`` to delegate + to its ``run_async`` so the default linear policy dispatch path works. + + Reads ``run_async`` dynamically each call, so tests that replace + ``attack.run_async`` after fixture setup don't need to re-wire + ``process_async``. + """ + + async def _fake_process(*args, **kwargs): + executor_result = await attack.run_async( + max_concurrency=getattr(attack, "_scenario_max_concurrency", 1), + return_partial_on_failure=True, + ) + return ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + + attack.process_async = MagicMock(side_effect=_fake_process) + + +def _wire_atomic_attack_mock(attack: MagicMock, *, name: str) -> None: + """Apply the canonical setter / process_async wiring for an AtomicAttack mock.""" + attack._scenario_result_id = None + attack._scenario_max_concurrency = 1 + attack.name = name + attack.set_scenario_result_id = MagicMock(side_effect=lambda sid: setattr(attack, "_scenario_result_id", sid)) + attack.set_scenario_max_concurrency = MagicMock( + side_effect=lambda mc: setattr(attack, "_scenario_max_concurrency", mc) + ) + _wire_process_async(attack) + + def save_attack_results_to_memory(attack_results): """Helper function to save attack results to memory (mimics what real attacks do).""" memory = CentralMemory.get_memory_instance() @@ -77,24 +117,21 @@ def mock_atomic_attacks(): run1.atomic_attack_name = "attack_run_1" run1.display_group = "attack_run_1" run1._attack = mock_attack - run1._scenario_result_id = None - run1.set_scenario_result_id = MagicMock(side_effect=lambda sid: setattr(run1, "_scenario_result_id", sid)) + _wire_atomic_attack_mock(run1, name="attack_run_1") type(run1).objectives = PropertyMock(return_value=["objective1"]) run2 = MagicMock(spec=AtomicAttack) run2.atomic_attack_name = "attack_run_2" run2.display_group = "attack_run_2" run2._attack = mock_attack - run2._scenario_result_id = None - run2.set_scenario_result_id = MagicMock(side_effect=lambda sid: setattr(run2, "_scenario_result_id", sid)) + _wire_atomic_attack_mock(run2, name="attack_run_2") type(run2).objectives = PropertyMock(return_value=["objective2"]) run3 = MagicMock(spec=AtomicAttack) run3.atomic_attack_name = "attack_run_3" run3.display_group = "attack_run_3" run3._attack = mock_attack - run3._scenario_result_id = None - run3.set_scenario_result_id = MagicMock(side_effect=lambda sid: setattr(run3, "_scenario_result_id", sid)) + _wire_atomic_attack_mock(run3, name="attack_run_3") type(run3).objectives = PropertyMock(return_value=["objective3"]) return [run1, run2, run3] diff --git a/tests/unit/scenario/test_scenario_graph_execution.py b/tests/unit/scenario/test_scenario_graph_execution.py index 2cc96e687..8d5a8f5c4 100644 --- a/tests/unit/scenario/test_scenario_graph_execution.py +++ b/tests/unit/scenario/test_scenario_graph_execution.py @@ -56,16 +56,27 @@ def _save_results_to_memory(attack_results, *, atomic_attack=None): def _make_atomic_attack_mock(name: str, attack_result: AttackResult) -> MagicMock: - """Build a fake AtomicAttack whose run_async returns the supplied result.""" + """Build a fake AtomicAttack whose run_async returns the supplied result. + + Wires both ``run_async`` (legacy direct-dispatch path, asserted by some + tests) and ``process_async`` (current unified-dispatch path used by the + default linear policy after R1). The two side-effects are layered so + ``process_async`` invokes ``run_async`` internally — mirroring the real + :meth:`AtomicAttack.process_async` — which keeps existing + ``attack.run_async.assert_called_once_with(...)`` assertions valid while + the orchestrator dispatches via ``process_async``. + """ mock_attack = MagicMock() mock_attack.get_objective_target.return_value = MagicMock() mock_attack.get_attack_scoring_config.return_value = MagicMock() attack = MagicMock(spec=AtomicAttack) attack.atomic_attack_name = name + attack.name = name attack.display_group = name attack._attack = mock_attack attack._scenario_result_id = None + attack._scenario_max_concurrency = 1 type(attack).objectives = PropertyMock(return_value=[attack_result.objective]) def _set_scenario_result_id(sid): @@ -73,11 +84,32 @@ def _set_scenario_result_id(sid): attack.set_scenario_result_id = MagicMock(side_effect=_set_scenario_result_id) + def _set_scenario_max_concurrency(mc): + attack._scenario_max_concurrency = mc + + attack.set_scenario_max_concurrency = MagicMock(side_effect=_set_scenario_max_concurrency) + async def _fake_run(*args, **kwargs): _save_results_to_memory([attack_result], atomic_attack=attack) return AttackExecutorResult(completed_results=[attack_result], incomplete_objectives=[]) attack.run_async = MagicMock(side_effect=_fake_run) + + async def _fake_process(*args, **kwargs): + executor_result = await attack.run_async( + max_concurrency=attack._scenario_max_concurrency, + return_partial_on_failure=True, + ) + return ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + + attack.process_async = MagicMock(side_effect=_fake_process) return attack diff --git a/tests/unit/scenario/test_scenario_partial_results.py b/tests/unit/scenario/test_scenario_partial_results.py index 91e3f27cd..38e41a256 100644 --- a/tests/unit/scenario/test_scenario_partial_results.py +++ b/tests/unit/scenario/test_scenario_partial_results.py @@ -14,6 +14,7 @@ from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy, Scenario, ScenarioStrategy +from pyrit.scenario.core.scenario_step import ScenarioStepResult def _mock_scorer_id(name: str = "MockScorer") -> ComponentIdentifier: @@ -68,15 +69,38 @@ def create_mock_atomic_attack(name: str, objectives: list[str]) -> MagicMock: attack = MagicMock(spec=AtomicAttack) attack.atomic_attack_name = name + attack.name = name attack.display_group = name attack._attack = mock_attack_strategy attack._scenario_result_id = None + attack._scenario_max_concurrency = 1 def _set_scenario_result_id(scenario_result_id): attack._scenario_result_id = scenario_result_id attack.set_scenario_result_id = MagicMock(side_effect=_set_scenario_result_id) + def _set_scenario_max_concurrency(max_concurrency): + attack._scenario_max_concurrency = max_concurrency + + attack.set_scenario_max_concurrency = MagicMock(side_effect=_set_scenario_max_concurrency) + + async def _fake_process(*args, **kwargs): + executor_result = await attack.run_async( + max_concurrency=attack._scenario_max_concurrency, + return_partial_on_failure=True, + ) + return ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + + attack.process_async = MagicMock(side_effect=_fake_process) + original_objectives = list(objectives) current_objectives = {"value": list(objectives)} diff --git a/tests/unit/scenario/test_scenario_retry.py b/tests/unit/scenario/test_scenario_retry.py index d26cb1ae0..c88ff1385 100644 --- a/tests/unit/scenario/test_scenario_retry.py +++ b/tests/unit/scenario/test_scenario_retry.py @@ -14,6 +14,7 @@ from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy, Scenario, ScenarioStrategy +from pyrit.scenario.core.scenario_step import ScenarioStepResult # Test constants TEST_ATTACK_TYPE = "TestAttack" @@ -136,15 +137,38 @@ def create_mock_atomic_attack(name: str, objectives: list[str], run_async_mock: attack = MagicMock(spec=AtomicAttack) attack.atomic_attack_name = name + attack.name = name attack.display_group = name attack._attack = mock_attack_strategy attack._scenario_result_id = None + attack._scenario_max_concurrency = 1 def _set_scenario_result_id(scenario_result_id): attack._scenario_result_id = scenario_result_id attack.set_scenario_result_id = MagicMock(side_effect=_set_scenario_result_id) + def _set_scenario_max_concurrency(max_concurrency): + attack._scenario_max_concurrency = max_concurrency + + attack.set_scenario_max_concurrency = MagicMock(side_effect=_set_scenario_max_concurrency) + + async def _fake_process(*args, **kwargs): + executor_result = await attack.run_async( + max_concurrency=attack._scenario_max_concurrency, + return_partial_on_failure=True, + ) + return ScenarioStepResult( + outcome="done", + attack_results=list(executor_result.completed_results), + metadata={ + "incomplete_objectives": list(executor_result.incomplete_objectives), + "input_indices": list(executor_result.input_indices), + }, + ) + + attack.process_async = MagicMock(side_effect=_fake_process) + # Track objectives + objective-hash mapping so the hash-based filter # behaves correctly in resume tests. from pyrit.common.utils import to_sha256 From e3518d5eedefa3caa2ca9f4453afb1996735dc6e Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 14:12:48 -0700 Subject: [PATCH 40/42] MAINT: R2 rename _get_atomic_attacks_async to _get_steps_async Renames the step-builder hook on Scenario from _get_atomic_attacks_async to _get_steps_async to honestly reflect that subclasses may return any ScenarioStep (AtomicAttack, AdaptiveStep, _ScenarioPipelinePhaseStep, etc.), not just AtomicAttacks. The legacy name keeps working as a passthrough through 0.16.0. Base class now exposes _get_steps_async as the real factory (cross-product over selected techniques and datasets). _get_atomic_attacks_async stays as a thin delegate. __init_subclass__ detects subclasses that still override only the legacy name and emits a DeprecationWarning once at class-creation time so authors see the rename horizon before their next run_async() call. Internal callsite in initialize_async now invokes _get_steps_async; the existing baseline-injection rescue path is unchanged. Migrates all 8 first-party Scenario subclasses (adaptive, adversarial, red_team_agent, encoding, jailbreak, psychosocial, scam, sweep_then_deep_dive), LinearScenario, and ScenarioPipeline to the new name. Test fixtures across the scenario suite are migrated except for the two that intentionally exercise the legacy rescue path (test_baseline_deprecation, test_scenario._LegacyOverrideScenario). Walkthroughs in doc/ and .github/instructions/scenarios.instructions.md updated with the rename plus a deprecation pointer. Adds tests/unit/scenario/test_get_steps_async_rename.py pinning: legacy-override-only emits the warning, new-override-only stays quiet, both-overrides stays quiet, neither-override stays quiet, legacy override reached via _get_steps_async delegation, and new override reached via _get_atomic_attacks_async passthrough. Per rlundeen review on #1767: surfaces R2 from the R-series rollout (R1 collapsed the adaptive override into the base linear policy; R3 will split scenario/step state). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../instructions/scenarios.instructions.md | 25 +-- doc/code/scenarios/0_scenarios.py | 10 +- doc/scanner/1_pyrit_scan.py | 6 +- pyrit/scenario/composite/scenario_pipeline.py | 2 +- pyrit/scenario/core/linear_scenario.py | 6 +- pyrit/scenario/core/scenario.py | 78 ++++++++-- pyrit/scenario/core/strategy_graph.py | 2 +- .../scenarios/adaptive/adaptive_scenario.py | 10 +- pyrit/scenario/scenarios/airt/jailbreak.py | 4 +- pyrit/scenario/scenarios/airt/psychosocial.py | 4 +- pyrit/scenario/scenarios/airt/scam.py | 4 +- .../scenarios/airt/sweep_then_deep_dive.py | 2 +- .../scenarios/benchmark/adversarial.py | 4 +- .../scenarios/foundry/red_team_agent.py | 2 +- pyrit/scenario/scenarios/garak/encoding.py | 4 +- .../scenario/test_get_steps_async_rename.py | 142 ++++++++++++++++++ tests/unit/scenario/test_scenario.py | 4 +- .../scenario/test_scenario_graph_execution.py | 2 +- .../unit/scenario/test_scenario_parameters.py | 2 +- .../scenario/test_scenario_partial_results.py | 2 +- tests/unit/scenario/test_scenario_retry.py | 2 +- tests/unit/scenario/test_waterfall.py | 2 +- 22 files changed, 263 insertions(+), 56 deletions(-) create mode 100644 tests/unit/scenario/test_get_steps_async_rename.py diff --git a/.github/instructions/scenarios.instructions.md b/.github/instructions/scenarios.instructions.md index 867375f8a..3be1dd45c 100644 --- a/.github/instructions/scenarios.instructions.md +++ b/.github/instructions/scenarios.instructions.md @@ -34,9 +34,13 @@ class MyScenario(Scenario): return DatasetConfiguration(dataset_names=["my_dataset"]) ``` -4. **Optionally override `_get_atomic_attacks_async()`** — the base class provides a default +4. **Optionally override `_get_steps_async()`** — the base class provides a default that uses the factory/registry pattern (see "AtomicAttack Construction" below). - Only override if your scenario needs custom attack construction logic. + Only override if your scenario needs custom step construction logic. + + > **Deprecation note:** The legacy hook `_get_atomic_attacks_async()` still works as a + > passthrough but is deprecated and will be removed in 0.16.0. Migrate overrides to + > `_get_steps_async()` — the body is identical; only the name changes. ## Constructor Pattern @@ -53,7 +57,7 @@ def __init__( if not objective_scorer: objective_scorer = self._get_default_scorer() - # 2. Store config objects for _get_atomic_attacks_async + # 2. Store config objects for _get_steps_async self._scorer_config = AttackScoringConfig(objective_scorer=objective_scorer) # 3. Call super().__init__ — required args: version, strategy_class, objective_scorer @@ -139,9 +143,12 @@ Note: `atomic_attack_name` must remain unique per `AtomicAttack` for correct res ## AtomicAttack Construction — Default Base Class Behaviour -The `Scenario` base class provides a default `_get_atomic_attacks_async()` that uses the +The `Scenario` base class provides a default `_get_steps_async()` that uses the factory/registry pattern. Scenarios that register their techniques via `_get_attack_technique_factories()` -get atomic-attack construction **for free** — no override needed. +get step construction **for free** — no override needed. + +> The legacy hook `_get_atomic_attacks_async()` still works as a passthrough but is +> deprecated and will be removed in 0.16.0. Use `_get_steps_async()` for new code. The default implementation: 1. Calls `self._get_attack_technique_factories()` to get name→factory mapping @@ -150,13 +157,13 @@ The default implementation: 4. Uses `self._build_display_group()` for user-facing grouping 5. Builds `AtomicAttack` with unique `atomic_attack_name` = `"{technique}_{dataset}"` -### Customization hooks (no need to override `_get_atomic_attacks_async`): +### Customization hooks (no need to override `_get_steps_async`): - **`_get_attack_technique_factories()`** — override to add/remove/replace factories - **`_build_display_group()`** — override to change grouping (default: by technique) -### When to override `_get_atomic_attacks_async`: +### When to override `_get_steps_async`: Only override when the scenario **cannot** use the factory/registry pattern — e.g., scenarios -with custom composite logic, per-strategy converter stacks, or non-standard attack construction. +with custom composite logic, per-strategy converter stacks, or non-standard step construction. Overrides that want baseline support must emit it themselves by calling `self._build_baseline_atomic_attack(seed_groups=...)` with the same seeds used for the strategy attacks and prepending the result. The base implementation emits baseline automatically; passing freshly resolved seeds reintroduces ADO 9012 (baseline-vs-strategy population divergence under `max_dataset_size`). @@ -185,4 +192,4 @@ New scenarios must be registered in `pyrit/scenario/__init__.py` as virtual pack - Forgetting `@apply_defaults` on `__init__` - Empty `seed_groups` passed to `AtomicAttack` - Missing `VERSION` class constant -- Missing `_async` suffix on `_get_atomic_attacks_async` +- Missing `_async` suffix on `_get_steps_async` diff --git a/doc/code/scenarios/0_scenarios.py b/doc/code/scenarios/0_scenarios.py index bf04a72f6..c495ff5cd 100644 --- a/doc/code/scenarios/0_scenarios.py +++ b/doc/code/scenarios/0_scenarios.py @@ -64,8 +64,10 @@ # 2. **Scenario Class**: Extend `Scenario` and implement these abstract methods: # - `get_strategy_class()`: Return your strategy enum class # - `get_default_strategy()`: Return the default strategy (typically `YourStrategy.ALL`) -# - The base class provides a default `_get_atomic_attacks_async()` that uses the factory/registry -# pattern. Override it only if your scenario needs custom attack construction logic. +# - The base class provides a default `_get_steps_async()` that uses the factory/registry +# pattern. Override it only if your scenario needs custom step construction logic. +# (The legacy hook `_get_atomic_attacks_async()` still works but is deprecated and will +# be removed in 0.16.0.) # # 3. **Default Dataset**: Implement `default_dataset_config()` to specify the datasets your scenario uses out of the box. # - Returns a `DatasetConfiguration` with one or more named datasets (e.g., `DatasetConfiguration(dataset_names=["my_dataset"])`) @@ -155,8 +157,8 @@ def __init__( def _build_display_group(self, *, technique_name: str, seed_group_name: str) -> str: return seed_group_name - # No _get_atomic_attacks_async override needed! - # The base class builds attacks from the (technique x dataset) cross-product + # No _get_steps_async override needed! + # The base class builds steps from the (technique x dataset) cross-product # using the factory/registry pattern automatically. diff --git a/doc/scanner/1_pyrit_scan.py b/doc/scanner/1_pyrit_scan.py index 904b948d5..f839ee2c4 100644 --- a/doc/scanner/1_pyrit_scan.py +++ b/doc/scanner/1_pyrit_scan.py @@ -166,10 +166,10 @@ def __init__(self, *, scenario_result_id=None, **kwargs): ) # ... your scenario-specific initialization code - async def _get_atomic_attacks_async(self): - # Override only if your scenario needs custom attack construction. + async def _get_steps_async(self): + # Override only if your scenario needs custom step construction. # The base class provides a default that uses the factory/registry pattern. - # Example: create attacks for each strategy composite + # Example: create steps for each strategy composite return [] diff --git a/pyrit/scenario/composite/scenario_pipeline.py b/pyrit/scenario/composite/scenario_pipeline.py index 067dd539c..cd356505a 100644 --- a/pyrit/scenario/composite/scenario_pipeline.py +++ b/pyrit/scenario/composite/scenario_pipeline.py @@ -643,7 +643,7 @@ def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: """ return {} - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Materialize one :class:`_ScenarioPipelinePhaseStep` per phase. diff --git a/pyrit/scenario/core/linear_scenario.py b/pyrit/scenario/core/linear_scenario.py index 05facb452..413123344 100644 --- a/pyrit/scenario/core/linear_scenario.py +++ b/pyrit/scenario/core/linear_scenario.py @@ -11,7 +11,7 @@ :meth:`Scenario._build_default_linear_policy`; no graph or state vocabulary is exposed to the caller. -L1 (override ``_get_atomic_attacks_async``) and L2 (override +L1 (override ``_get_steps_async``) and L2 (override ``_build_execution_graph``) remain available for scenarios that need registry-driven technique selection or non-linear control flow respectively. """ @@ -144,12 +144,12 @@ def default_dataset_config(cls) -> DatasetConfiguration: """ return DatasetConfiguration() - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Return the caller-supplied steps in order. The return type is :class:`list[AtomicAttack]` for parity with the - base ``Scenario._get_atomic_attacks_async`` contract; the resume / + base ``Scenario._get_steps_async`` contract; the resume / orchestrator code reads the duck-typed attributes that every ``ScenarioStep`` exposes (``name``, ``process_async``) so non- ``AtomicAttack`` subclasses pass through cleanly. diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 76a90a8cf..7b0659658 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -154,6 +154,29 @@ class Scenario(ABC): #: caller-supplied ``include_baseline=True`` raises ``ValueError``. BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Enabled + def __init_subclass__(cls, **kwargs: Any) -> None: + """ + Warn once per subclass that still overrides the legacy step builder. + + ``_get_atomic_attacks_async`` was renamed to :meth:`_get_steps_async` + in PyRIT 0.15; the old name keeps working through a passthrough but + will be removed in 0.16. We detect the override at class-creation + time so authors see the deprecation as soon as their subclass module + is imported instead of only on the next ``run_async`` call. + + Args: + **kwargs (Any): Forwarded to ``ABC.__init_subclass__``. + """ + super().__init_subclass__(**kwargs) + overrides_legacy = "_get_atomic_attacks_async" in cls.__dict__ + overrides_new = "_get_steps_async" in cls.__dict__ + if overrides_legacy and not overrides_new: + print_deprecation_message( + old_item=f"{cls.__module__}.{cls.__qualname__}._get_atomic_attacks_async", + new_item=f"{cls.__module__}.{cls.__qualname__}._get_steps_async", + removed_in="0.16.0", + ) + @classmethod def _get_additional_scoring_questions(cls) -> Sequence[Path]: """ @@ -195,7 +218,9 @@ def __init__( Note: Attack runs are populated by calling initialize_async(), which invokes the - subclass's _get_atomic_attacks_async() method. + subclass's _get_steps_async() method (or, for legacy subclasses still + overriding the deprecated _get_atomic_attacks_async, the base + _get_steps_async delegates to that legacy override). The scenario description is automatically extracted from the class's docstring (__doc__) with whitespace normalized for display. @@ -226,7 +251,7 @@ def __init__( self._atomic_attacks: list[AtomicAttack] = [] self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None - # Store prepared strategies for use in _get_atomic_attacks_async + # Store prepared strategies for use in _get_steps_async self._scenario_strategies: list[ScenarioStrategy] = [] # Maps atomic_attack_name → display_group for user-facing aggregation @@ -237,7 +262,7 @@ def __init__( self._declarations_validated: bool = False # Resolved effective baseline inclusion for the current run. Set in initialize_async - # before _get_atomic_attacks_async is awaited so overrides can read it. + # before _get_steps_async is awaited so overrides can read it. self._include_baseline: bool = False # Phase 5: state-machine view over the scenario's steps. Built lazily in @@ -745,7 +770,7 @@ async def initialize_async( if not self._declarations_validated: self.set_params_from_args(args={}) - self._atomic_attacks = await self._get_atomic_attacks_async() + self._atomic_attacks = await self._get_steps_async() # Deprecation rescue. Will be removed in 0.16.0. If the override didn't emit baseline, # warn and inject. Migrated overrides emit baseline themselves and bypass this branch. @@ -1097,14 +1122,26 @@ async def _get_remaining_atomic_attacks_async(self) -> list[AtomicAttack]: return remaining_attacks - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ - Build atomic attacks from the cross-product of selected techniques and datasets. + Build the steps this scenario will execute. - Uses ``_get_attack_technique_factories()`` to obtain factories, then - iterates over every (technique, dataset) pair to create an - ``AtomicAttack`` for each. Grouping for display is controlled by - ``_build_display_group()``. + Returns the list of :class:`AtomicAttack` instances the orchestrator + walks via the default linear policy. Subclasses override this method + to author custom step inventories — adaptive selectors, hand-rolled + composites, or wrappers around the registry pattern. + + The default implementation builds atomic attacks from the cross-product + of selected techniques and datasets. Uses + ``_get_attack_technique_factories()`` to obtain factories, then iterates + over every (technique, dataset) pair to create an ``AtomicAttack`` for + each. Grouping for display is controlled by ``_build_display_group()``. + + For backward compatibility, subclasses that still override + :meth:`_get_atomic_attacks_async` are detected automatically and routed + through that override; a deprecation warning is emitted once per such + subclass at class-creation time. Removal of the old method is planned + for ``0.16.0``. Subclasses that do **not** use the factory/registry pattern should override this method entirely. Overrides that want baseline support @@ -1112,11 +1149,17 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: seeds. Returns: - list[AtomicAttack]: The generated atomic attacks. + list[AtomicAttack]: The generated steps. Raises: ValueError: If the scenario has not been initialized. """ + # Legacy-override delegation: if a subclass still overrides the old + # name (and didn't also override _get_steps_async), call that override + # so we don't lose its behavior during the deprecation window. + if type(self)._get_atomic_attacks_async is not Scenario._get_atomic_attacks_async: + return await self._get_atomic_attacks_async() + if self._objective_target is None: raise ValueError( "Scenario not properly initialized. Call await scenario.initialize_async() before running." @@ -1185,6 +1228,19 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: return atomic_attacks + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + """ + Delegate to :meth:`_get_steps_async`. + + Kept as a passthrough so existing subclass overrides keep working + through the deprecation window. New scenarios should override + :meth:`_get_steps_async` directly. Will be removed in ``0.16.0``. + + Returns: + list[AtomicAttack]: Delegates to :meth:`_get_steps_async`. + """ + return await self._get_steps_async() + def _build_execution_graph( self, *, steps: Optional[Sequence[ScenarioStep]] = None ) -> StrategyGraph[ScenarioStep, int]: diff --git a/pyrit/scenario/core/strategy_graph.py b/pyrit/scenario/core/strategy_graph.py index 48bd342e3..5addfb80e 100644 --- a/pyrit/scenario/core/strategy_graph.py +++ b/pyrit/scenario/core/strategy_graph.py @@ -12,7 +12,7 @@ This module also exposes ``linear_strategy_policy``, a convenience builder that produces a trivial "run steps 0..N-1 in order" policy. Phase 5 will use it to silently upgrade scenarios that still declare their steps as a flat -list (via the legacy ``_get_atomic_attacks_async`` override) without forcing +list (via the ``_get_steps_async`` override) without forcing those scenarios to author a custom policy. """ diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index 9e615eb34..2c19ec6e7 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -93,7 +93,7 @@ def __init__( """ # Validate scalar inputs eagerly. ``AdaptiveTechniqueSelector`` and # ``AdaptiveStep`` perform the same checks, but only when constructed - # lazily inside ``_get_atomic_attacks_async`` (called from + # lazily inside ``_get_steps_async`` (called from # ``initialize_async``). Failing fast at __init__ matches the # elicitation surface advertised by ``input_schema`` so wizard / # programmatic callers get the error on the same line they supplied @@ -114,7 +114,7 @@ def __init__( self._max_attempts_per_objective = max_attempts_per_objective self._seed = seed self._context_extractor = context_extractor - # Populated by _get_atomic_attacks_async; consumed by _build_execution_graph + # Populated by _get_steps_async; consumed by _build_execution_graph # only when an override path needs to introspect it externally. self._selector: AdaptiveTechniqueSelector | None = None @@ -182,7 +182,7 @@ def input_schema(cls) -> list[RoleDescriptor]: ), ] - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Build one :class:`AdaptiveStep` per objective. @@ -193,7 +193,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: with no compatible techniques are skipped. The return type is :class:`list[AtomicAttack]` for parity with the base - ``Scenario._get_atomic_attacks_async`` contract — the orchestrator's + ``Scenario._get_steps_async`` contract — the orchestrator's resume bookkeeping treats steps via the duck-typed attributes :class:`AdaptiveStep` provides (``atomic_attack_name``, ``objectives``, ``seed_groups``, ``display_group``, ``filter_seed_groups_by_objectives``). @@ -313,7 +313,7 @@ def _build_step_for_seed_group( Raises: ValueError: If ``self._objective_target`` is not set (defensive - guard; ``_get_atomic_attacks_async`` enforces this earlier). + guard; ``_get_steps_async`` enforces this earlier). """ if self._objective_target is None: # pragma: no cover - defensive raise ValueError("objective_target must be set before creating attacks") diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index f69b55d01..26880b3e3 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -197,7 +197,7 @@ def __init__( ) self._legacy_include_baseline = include_baseline - # Will be resolved in _get_atomic_attacks_async + # Will be resolved in _get_steps_async self._seed_groups: Optional[list[SeedAttackGroup]] = None def _get_or_create_adversarial_target(self) -> PromptTarget: @@ -295,7 +295,7 @@ async def _get_atomic_attack_from_strategy_async( seed_groups=self._seed_groups or [], ) - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Generate atomic attacks for each jailbreak template. diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 8e7bb0bd5..ec96ace34 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -282,7 +282,7 @@ def __init__( # Store deprecated objectives for later resolution in _resolve_seed_groups self._deprecated_objectives = objectives - # Will be resolved in _get_atomic_attacks_async + # Will be resolved in _get_steps_async self._seed_groups: Optional[list[SeedAttackGroup]] = None def _resolve_seed_groups(self) -> ResolvedSeedData: @@ -418,7 +418,7 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore # threshold=1.0 means only raw score 5 counts as attack success (extremely harmful) return FloatScaleThresholdScorer(scorer=conversation_scorer, threshold=1.0) - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: if self._objective_target is None: raise ValueError("objective_target must be set before creating attacks") try: diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index 84b57cffb..2e8b35289 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -193,7 +193,7 @@ def __init__( ) self._legacy_include_baseline = include_baseline - # Will be resolved in _get_atomic_attacks_async + # Will be resolved in _get_steps_async self._seed_groups: Optional[list[SeedAttackGroup]] = None def _resolve_seed_groups(self) -> list[SeedAttackGroup]: @@ -269,7 +269,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: memory_labels=self._memory_labels, ) - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Generate atomic attacks for each strategy. diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index 4af7ca026..0f73a5171 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -621,7 +621,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: """ return DatasetConfiguration() - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Return the atomics in canonical phase order: sweep first, deep dives after. diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index 33d8d5d8f..aea0d94b2 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -142,7 +142,7 @@ def __init__( scenario_result_id=scenario_result_id, ) - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Build atomic attacks from the cross-product of techniques × models × datasets. @@ -323,7 +323,7 @@ def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: inspection. The returned factories are not used to execute attacks (that path runs - through :meth:`_get_atomic_attacks_async` with its own local factory + through :meth:`_get_steps_async` with its own local factory construction); they exist purely as the registry-shaped catalog of techniques this scenario uses. diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index b9ce521fb..e26ced454 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -423,7 +423,7 @@ def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ return self._dataset_config.get_all_seed_attack_groups() - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Retrieve the list of AtomicAttack instances in this scenario. diff --git a/pyrit/scenario/scenarios/garak/encoding.py b/pyrit/scenario/scenarios/garak/encoding.py index c20ece87b..c9fb0eced 100644 --- a/pyrit/scenario/scenarios/garak/encoding.py +++ b/pyrit/scenario/scenarios/garak/encoding.py @@ -212,7 +212,7 @@ def __init__( ) self._legacy_include_baseline = include_baseline - # Will be resolved in _get_atomic_attacks_async + # Will be resolved in _get_steps_async self._resolved_seed_groups: Optional[list[SeedAttackGroup]] = None def _resolve_seed_groups(self) -> list[SeedAttackGroup]: @@ -230,7 +230,7 @@ def _resolve_seed_groups(self) -> list[SeedAttackGroup]: return seed_groups - async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + async def _get_steps_async(self) -> list[AtomicAttack]: """ Retrieve the list of AtomicAttack instances in this scenario. diff --git a/tests/unit/scenario/test_get_steps_async_rename.py b/tests/unit/scenario/test_get_steps_async_rename.py new file mode 100644 index 000000000..70edf2adc --- /dev/null +++ b/tests/unit/scenario/test_get_steps_async_rename.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests pinning the R2 deprecation contract for the +``_get_atomic_attacks_async`` → ``_get_steps_async`` rename. + +The rename ships with a passthrough shim plus an ``__init_subclass__`` +deprecation hook on :class:`Scenario`. These tests pin: + +* Subclasses overriding the legacy name only → ``DeprecationWarning`` fires at + class creation, and the orchestrator still receives the legacy override's + output via the base ``_get_steps_async`` delegation path. +* Subclasses overriding the new name only → no warning; legacy + ``_get_atomic_attacks_async`` callers still get the new override's output via + the base passthrough. +* Subclasses overriding both → no warning; new name wins. +* Subclasses overriding neither → no warning; default factory body runs. +""" + +from __future__ import annotations + +import warnings +from typing import ClassVar +from unittest.mock import MagicMock + +import pytest + +from pyrit.scenario.core.atomic_attack import AtomicAttack +from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario +from pyrit.scenario.core.scenario_strategy import ScenarioStrategy + + +class _R2Strategy(ScenarioStrategy): + DEFAULT = ("default", {"default"}) + + +class _R2ScenarioBase(Scenario): + """Common abstract-method satisfaction for the R2 rename tests.""" + + BASELINE_ATTACK_POLICY: ClassVar[BaselineAttackPolicy] = BaselineAttackPolicy.Disabled + + @classmethod + def get_strategy_class(cls): + return _R2Strategy + + @classmethod + def get_default_strategy(cls): + return _R2Strategy.DEFAULT + + @classmethod + def default_dataset_config(cls) -> DatasetConfiguration: + return DatasetConfiguration() + + +def _build_scenario_kwargs() -> dict: + objective_scorer = MagicMock() + objective_scorer.get_identifier.return_value = {"id": "r2-scorer"} + objective_scorer.get_scorer_metrics.return_value = None + return { + "name": "r2", + "version": 1, + "strategy_class": _R2Strategy, + "objective_scorer": objective_scorer, + } + + +class TestGetStepsAsyncRenameDeprecation: + """Pin the class-creation deprecation warning surface.""" + + def test_legacy_override_only_emits_deprecation(self) -> None: + with pytest.warns(DeprecationWarning, match=r"_get_atomic_attacks_async"): + + class _LegacyOnly(_R2ScenarioBase): + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + return [] + + assert _LegacyOnly is not None + + def test_new_override_only_does_not_warn(self) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + + class _NewOnly(_R2ScenarioBase): + async def _get_steps_async(self) -> list[AtomicAttack]: + return [] + + assert _NewOnly is not None + + def test_both_overrides_does_not_warn(self) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + + class _Both(_R2ScenarioBase): + async def _get_steps_async(self) -> list[AtomicAttack]: + return [] + + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + return [] + + assert _Both is not None + + def test_no_override_does_not_warn(self) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + + class _Neither(_R2ScenarioBase): + pass + + assert _Neither is not None + + +class TestGetStepsAsyncRenameDelegation: + """Pin the runtime delegation behavior of both directions.""" + + @pytest.mark.asyncio + @pytest.mark.usefixtures("patch_central_database") + async def test_legacy_override_reached_via_get_steps_async(self) -> None: + sentinel: list[AtomicAttack] = [MagicMock(spec=AtomicAttack)] + + with pytest.warns(DeprecationWarning): + + class _LegacyOnly(_R2ScenarioBase): + async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: + return sentinel + + scenario = _LegacyOnly(**_build_scenario_kwargs()) + result = await scenario._get_steps_async() + assert result is sentinel + + @pytest.mark.asyncio + @pytest.mark.usefixtures("patch_central_database") + async def test_new_override_reached_via_legacy_passthrough(self) -> None: + sentinel: list[AtomicAttack] = [MagicMock(spec=AtomicAttack)] + + class _NewOnly(_R2ScenarioBase): + async def _get_steps_async(self) -> list[AtomicAttack]: + return sentinel + + scenario = _NewOnly(**_build_scenario_kwargs()) + result = await scenario._get_atomic_attacks_async() + assert result is sentinel diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 2d98425a4..06663aaaa 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -220,7 +220,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: """Return the default dataset configuration for testing.""" return DatasetConfiguration() - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): return self._atomic_attacks_to_return @@ -1013,7 +1013,7 @@ async def test_baseline_objectives_match_atomic_attacks_under_max_dataset_size( config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) class StrategyScenario(ConcreteScenarioWithTrueFalseScorer): - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): groups_by_dataset = self._dataset_config.get_seed_attack_groups() all_seed_groups = [g for groups in groups_by_dataset.values() for g in groups] atomic_attacks = [ diff --git a/tests/unit/scenario/test_scenario_graph_execution.py b/tests/unit/scenario/test_scenario_graph_execution.py index 8d5a8f5c4..7eafa53b5 100644 --- a/tests/unit/scenario/test_scenario_graph_execution.py +++ b/tests/unit/scenario/test_scenario_graph_execution.py @@ -179,7 +179,7 @@ def get_default_strategy(cls): def default_dataset_config(cls) -> DatasetConfiguration: return DatasetConfiguration() - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): return self._atomic_attacks_to_return diff --git a/tests/unit/scenario/test_scenario_parameters.py b/tests/unit/scenario/test_scenario_parameters.py index 9c8b6fe6c..d3fb08cc5 100644 --- a/tests/unit/scenario/test_scenario_parameters.py +++ b/tests/unit/scenario/test_scenario_parameters.py @@ -53,7 +53,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: def supported_parameters(cls) -> list[Parameter]: return list(params_to_declare) - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): return [] mock_scorer = MagicMock(spec=Scorer) diff --git a/tests/unit/scenario/test_scenario_partial_results.py b/tests/unit/scenario/test_scenario_partial_results.py index 38e41a256..07a0464c9 100644 --- a/tests/unit/scenario/test_scenario_partial_results.py +++ b/tests/unit/scenario/test_scenario_partial_results.py @@ -133,7 +133,7 @@ def __init__(self, *, atomic_attacks_to_return=None, objective_scorer=None, **kw super().__init__(strategy_class=strategy_class, objective_scorer=objective_scorer, **kwargs) self._test_atomic_attacks = atomic_attacks_to_return or [] - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): return self._test_atomic_attacks @classmethod diff --git a/tests/unit/scenario/test_scenario_retry.py b/tests/unit/scenario/test_scenario_retry.py index c88ff1385..b6a61c1ff 100644 --- a/tests/unit/scenario/test_scenario_retry.py +++ b/tests/unit/scenario/test_scenario_retry.py @@ -229,7 +229,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: """Return the default dataset configuration for testing.""" return DatasetConfiguration() - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): return self._atomic_attacks_to_return diff --git a/tests/unit/scenario/test_waterfall.py b/tests/unit/scenario/test_waterfall.py index 4138411e0..2de4f145e 100644 --- a/tests/unit/scenario/test_waterfall.py +++ b/tests/unit/scenario/test_waterfall.py @@ -88,7 +88,7 @@ def default_dataset_config(cls) -> DatasetConfiguration: def _get_attack_technique_factories(self) -> dict[str, AttackTechniqueFactory]: return self._factories_override - async def _get_atomic_attacks_async(self): + async def _get_steps_async(self): return [] From d5a910d5cc378d9f062c0b54e487171b251f1d6c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 14:20:05 -0700 Subject: [PATCH 41/42] MAINT: R3 split StrategyGraph state from step state Renames the singular current_step abstraction to active_steps (tuple) to prepare for R4 concurrent dispatch. Adds active_steps property + bind_active_steps mutator on StrategyGraph; keeps current_step + bind_current_step as backward-compat shims. current_step emits DeprecationWarning only when ambiguous (len(active_steps) > 1). Migrates all four first-party callsites to bind_active_steps: linear_strategy_policy, Scenario._build_default_linear_policy, ScenarioPipeline._build_phase_action, and BroadSweepThenDeepDive sweep+deep actions. Adds tests/unit/scenario/test_active_steps_split.py (10 tests) covering default state, sequential binding, shim semantics, concurrent binding warning, and reset behavior. Per rlundeen review on #1767: surfaces R3 from the R-series rollout (R2 renamed the step builder; R4 wires concurrent dispatch on top of this split). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/scenario/composite/scenario_pipeline.py | 8 +- pyrit/scenario/core/scenario.py | 8 +- pyrit/scenario/core/strategy_graph.py | 80 ++++++++++-- .../scenarios/airt/sweep_then_deep_dive.py | 8 +- .../unit/scenario/test_active_steps_split.py | 122 ++++++++++++++++++ 5 files changed, 201 insertions(+), 25 deletions(-) create mode 100644 tests/unit/scenario/test_active_steps_split.py diff --git a/pyrit/scenario/composite/scenario_pipeline.py b/pyrit/scenario/composite/scenario_pipeline.py index cd356505a..f2611f43c 100644 --- a/pyrit/scenario/composite/scenario_pipeline.py +++ b/pyrit/scenario/composite/scenario_pipeline.py @@ -738,8 +738,8 @@ def _build_phase_action( """ Build the policy action for one phase. - The action binds the phase step as the current step (so the - orchestrator can resolve ``graph.current_step`` mid-dispatch), + The action binds the phase step as the active step (so the + orchestrator can resolve ``graph.active_steps`` mid-dispatch), invokes ``step.process_async``, and advances to the next integer state. The step's metadata (``phase_name``, ``phase_index``) is merged into the result's metadata so the orchestrator's logging / @@ -758,7 +758,7 @@ def _build_phase_action( async def _phase_action( graph: StrategyGraph[ScenarioStep, int], ) -> tuple[int, ScenarioStepResult | None]: - graph.bind_current_step(step=step) + graph.bind_active_steps(steps=(step,)) try: base_result = await step.process_async() # Pipeline diagnostic keys (``step_name``, ``phase_index``) must @@ -779,7 +779,7 @@ async def _phase_action( metadata=merged_metadata, ) finally: - graph.bind_current_step(step=None) + graph.bind_active_steps(steps=()) return next_state, result return _phase_action diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 7b0659658..a97f0cf1e 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -1329,11 +1329,11 @@ async def _action( _step: ScenarioStep = step, _next: int = index + 1, ) -> tuple[int, ScenarioStepResult | None]: - graph.bind_current_step(step=_step) + graph.bind_active_steps(steps=(_step,)) try: base_result = await _step.process_async() # Stamp ``step_name`` so the orchestrator can route the - # result without depending on ``graph.current_step`` + # result without depending on ``graph.active_steps`` # (cleared before yield). Caller metadata wins on # collision so steps remain authoritative. merged_metadata = {"step_name": _step.name, **base_result.metadata} @@ -1344,7 +1344,7 @@ async def _action( metadata=merged_metadata, ) finally: - graph.bind_current_step(step=None) + graph.bind_active_steps(steps=()) return _next, result actions[index] = _action @@ -1535,7 +1535,7 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) step_position = completed_count # Track the most recent step we attempted so a step-raised exception - # can still log the offending step's name. ``graph.current_step`` is + # can still log the offending step's name. ``graph.active_steps`` is # cleared in the policy action's ``finally`` before the exception # propagates, so it's not a reliable post-mortem source. last_attempted_step_name: str = "" diff --git a/pyrit/scenario/core/strategy_graph.py b/pyrit/scenario/core/strategy_graph.py index 5addfb80e..ad03869d8 100644 --- a/pyrit/scenario/core/strategy_graph.py +++ b/pyrit/scenario/core/strategy_graph.py @@ -19,11 +19,13 @@ from __future__ import annotations import logging -from collections.abc import AsyncIterator, Awaitable, Callable, Hashable, Mapping, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Hashable, Iterable, Mapping, Sequence from dataclasses import dataclass, field from types import MappingProxyType from typing import TYPE_CHECKING, Generic, TypeVar +from pyrit.common.deprecation import print_deprecation_message + if TYPE_CHECKING: from pyrit.scenario.core.scenario_step import ScenarioStep, ScenarioStepResult @@ -127,9 +129,13 @@ class StrategyGraph(Generic[StepT, StateT]): result, and advances to the next state until a terminal state is reached. - The graph maintains ``current_state``, ``current_step``, and ``history`` + The graph maintains ``current_state``, ``active_steps``, and ``history`` so that retries can resume from the last persisted state without - replaying completed work. + replaying completed work. ``active_steps`` is a tuple so policy actions + that fan out to multiple concurrent steps (see R4 / asyncio.gather) have + a stable observable surface; single-step dispatch keeps the tuple at + length 0 or 1. The legacy singular ``current_step`` property is kept as + a backward-compat shim and is deprecated for direct use. """ def __init__( @@ -148,7 +154,7 @@ def __init__( """ self._policy = policy self._current_state: StateT = policy.initial_state - self._current_step: StepT | None = None + self._active_steps: tuple[StepT, ...] = () self._history: list[tuple[StateT, ScenarioStepResult]] = [] @property @@ -161,10 +167,41 @@ def current_state(self) -> StateT: """Return the graph's current state.""" return self._current_state + @property + def active_steps(self) -> tuple[StepT, ...]: + """ + Return the tuple of steps currently bound to the active state. + + Empty when no action is mid-dispatch. Length 1 for sequential + policies (the typical case). Length >1 when a concurrent policy + has fanned out to multiple steps via ``asyncio.gather``-style + execution (planned for R4). + """ + return self._active_steps + @property def current_step(self) -> StepT | None: - """Return the step bound to the current state, if the action set one.""" - return self._current_step + """ + Return the single active step, or ``None`` when idle. + + Deprecated: prefer :attr:`active_steps` for new code. This property + is preserved as a thin shim for legacy callers that assume sequential + dispatch (one step at a time). When the graph has fanned out to + multiple concurrent steps, this property returns the first one and + emits a :class:`DeprecationWarning` to flag the ambiguity. Removal + is planned for ``0.16.0``. + + Returns: + StepT | None: The first active step, or ``None`` if no step is + currently bound. + """ + if len(self._active_steps) > 1: + print_deprecation_message( + old_item="StrategyGraph.current_step (singular) while multiple steps are active", + new_item="StrategyGraph.active_steps (tuple) for concurrent dispatch", + removed_in="0.16.0", + ) + return self._active_steps[0] if self._active_steps else None @property def history(self) -> list[tuple[StateT, ScenarioStepResult]]: @@ -176,18 +213,35 @@ def is_terminal(self) -> bool: """Return ``True`` if the graph is in a terminal state.""" return self._policy.is_terminal(state=self._current_state) - def bind_current_step(self, *, step: StepT | None) -> None: + def bind_active_steps(self, *, steps: Iterable[StepT]) -> None: """ - Set the step bound to the current state. + Set the steps bound to the current state. Policy actions call this so external observers (e.g., the surrounding - ``Scenario``) can read ``graph.current_step`` while the action runs. + ``Scenario``) can read ``graph.active_steps`` while the action runs. + Pass an empty iterable to clear the binding. + + Args: + steps (Iterable[StepT]): The steps the action is about to execute, + or an empty iterable to clear. Concurrent policies pass the + full set of co-executing steps; sequential policies pass a + single-element iterable. + """ + self._active_steps = tuple(steps) + + def bind_current_step(self, *, step: StepT | None) -> None: + """ + Set the single step bound to the current state. + + Backward-compat shim over :meth:`bind_active_steps`. Sequential + policies that bind one step at a time may continue to use this name. + Passing ``None`` clears the binding. Args: step (StepT | None): The step the action is about to execute, or ``None`` to clear. """ - self._current_step = step + self.bind_active_steps(steps=() if step is None else (step,)) def reset(self) -> None: """ @@ -197,7 +251,7 @@ def reset(self) -> None: from the last persisted state. """ self._current_state = self._policy.initial_state - self._current_step = None + self._active_steps = () self._history = [] async def event_loop_async(self) -> AsyncIterator[ScenarioStepResult]: @@ -269,11 +323,11 @@ async def _action( _step: ScenarioStep = step, _next: int = index + 1, ) -> tuple[int, ScenarioStepResult | None]: - graph.bind_current_step(step=_step) + graph.bind_active_steps(steps=(_step,)) try: result = await _step.process_async() finally: - graph.bind_current_step(step=None) + graph.bind_active_steps(steps=()) return _next, result actions[index] = _action diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index 0f73a5171..a33abca7f 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -732,7 +732,7 @@ def _build_branching_policy( async def _sweep_action( graph: StrategyGraph[ScenarioStep, SweepThenDeepDiveState], ) -> tuple[SweepThenDeepDiveState, ScenarioStepResult | None]: - graph.bind_current_step(step=sweep_step) + graph.bind_active_steps(steps=(sweep_step,)) try: base_result = await sweep_step.process_async() # Update the closure-shared weak-categories set so the @@ -746,7 +746,7 @@ async def _sweep_action( metadata=merged_metadata, ) finally: - graph.bind_current_step(step=None) + graph.bind_active_steps(steps=()) next_state = ( SweepThenDeepDiveState.DEEP_DIVING @@ -758,7 +758,7 @@ async def _sweep_action( async def _deep_dive_action( graph: StrategyGraph[ScenarioStep, SweepThenDeepDiveState], ) -> tuple[SweepThenDeepDiveState, ScenarioStepResult | None]: - graph.bind_current_step(step=deep_dive_step) + graph.bind_active_steps(steps=(deep_dive_step,)) try: base_result = await deep_dive_step.process_async() merged_metadata = {"step_name": deep_dive_step.name, **base_result.metadata} @@ -769,7 +769,7 @@ async def _deep_dive_action( metadata=merged_metadata, ) finally: - graph.bind_current_step(step=None) + graph.bind_active_steps(steps=()) return SweepThenDeepDiveState.COMPLETE, result actions: dict[SweepThenDeepDiveState, PolicyAction[ScenarioStep, SweepThenDeepDiveState]] = { diff --git a/tests/unit/scenario/test_active_steps_split.py b/tests/unit/scenario/test_active_steps_split.py new file mode 100644 index 000000000..6d16249b9 --- /dev/null +++ b/tests/unit/scenario/test_active_steps_split.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the R3 ``StrategyGraph.active_steps`` state-split. + +The R3 refactor widens ``StrategyGraph``'s singular ``current_step`` cursor +into a tuple-valued ``active_steps`` field so a future concurrent-dispatch +policy (R4) has a stable, observable surface for fan-out execution. The +legacy ``current_step`` property is preserved as a backward-compat shim and +deprecated for direct use; it emits a ``DeprecationWarning`` only when the +graph has fanned out to more than one step (where collapsing to a single +``StepT | None`` would be ambiguous). +""" + +from __future__ import annotations + +import warnings +from unittest.mock import MagicMock + +import pytest + +from pyrit.scenario.core.scenario_step import ScenarioStep +from pyrit.scenario.core.strategy_graph import StrategyGraph, StrategyPolicy + + +def _build_idle_graph() -> StrategyGraph[ScenarioStep, int]: + """Build a minimal graph parked at its terminal state so binding tests don't drive the loop.""" + + async def _noop(graph: StrategyGraph[ScenarioStep, int]) -> tuple[int, None]: + return 1, None + + policy: StrategyPolicy[ScenarioStep, int] = StrategyPolicy( + actions={0: _noop}, + initial_state=0, + terminal_states=frozenset({1}), + ) + return StrategyGraph(policy=policy) + + +class TestActiveStepsDefault: + """The graph starts idle with an empty active_steps tuple.""" + + def test_initial_active_steps_is_empty_tuple(self) -> None: + graph = _build_idle_graph() + assert graph.active_steps == () + assert isinstance(graph.active_steps, tuple) + + def test_initial_current_step_is_none(self) -> None: + graph = _build_idle_graph() + assert graph.current_step is None + + +class TestBindActiveStepsSequential: + """Binding a single step exposes the same value through both surfaces.""" + + def test_single_step_visible_via_active_steps(self) -> None: + graph = _build_idle_graph() + step = MagicMock(spec=ScenarioStep) + graph.bind_active_steps(steps=(step,)) + assert graph.active_steps == (step,) + + def test_single_step_visible_via_current_step_without_warning(self) -> None: + graph = _build_idle_graph() + step = MagicMock(spec=ScenarioStep) + graph.bind_active_steps(steps=(step,)) + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + assert graph.current_step is step + + def test_clear_via_empty_tuple(self) -> None: + graph = _build_idle_graph() + graph.bind_active_steps(steps=(MagicMock(spec=ScenarioStep),)) + graph.bind_active_steps(steps=()) + assert graph.active_steps == () + assert graph.current_step is None + + +class TestBindCurrentStepShim: + """The legacy singular binder still works and routes through bind_active_steps.""" + + def test_bind_step_sets_active_steps_to_singleton(self) -> None: + graph = _build_idle_graph() + step = MagicMock(spec=ScenarioStep) + graph.bind_current_step(step=step) + assert graph.active_steps == (step,) + assert graph.current_step is step + + def test_bind_none_clears_active_steps(self) -> None: + graph = _build_idle_graph() + graph.bind_current_step(step=MagicMock(spec=ScenarioStep)) + graph.bind_current_step(step=None) + assert graph.active_steps == () + assert graph.current_step is None + + +class TestConcurrentBinding: + """Multi-step binding is the R3 contract for R4's concurrent dispatch.""" + + def test_multiple_steps_visible_via_active_steps(self) -> None: + graph = _build_idle_graph() + steps = (MagicMock(spec=ScenarioStep), MagicMock(spec=ScenarioStep), MagicMock(spec=ScenarioStep)) + graph.bind_active_steps(steps=steps) + assert graph.active_steps == steps + + def test_current_step_warns_on_concurrent_dispatch(self) -> None: + graph = _build_idle_graph() + steps = (MagicMock(spec=ScenarioStep), MagicMock(spec=ScenarioStep)) + graph.bind_active_steps(steps=steps) + with pytest.warns(DeprecationWarning, match=r"current_step"): + first = graph.current_step + assert first is steps[0] + + +class TestResetClearsActiveSteps: + """reset() returns the graph to its idle state, including active_steps.""" + + def test_reset_empties_active_steps(self) -> None: + graph = _build_idle_graph() + graph.bind_active_steps(steps=(MagicMock(spec=ScenarioStep), MagicMock(spec=ScenarioStep))) + graph.reset() + assert graph.active_steps == () + assert graph.current_step is None From 28428793815f439ae3b9ac0d46fe410eb0238c6d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 21 May 2026 14:29:34 -0700 Subject: [PATCH 42/42] MAINT: R4 concurrent deep-dive dispatch via asyncio.gather Adds max_step_concurrency (int, default 1) to BroadSweepThenDeepDive and FilteredDeepDiveStep. Default 1 preserves pre-R4 sequential semantics bit-for-bit. >1 wraps the per-atomic dispatch in an asyncio.Semaphore and awaits via asyncio.gather; dispatched_categories and attack_results retain input order because gather preserves it. Validates inputs at both layers (>= 1) so wizard / programmatic callers fail fast on bogus values. Stamps the effective concurrency cap into ScenarioStepResult.metadata['max_step_concurrency'] for downstream diagnostics. Surfaces the new scalar role through input_schema() so the wizard can elicit it (4 roles -> 5: 3 OPAQUE + 2 SCALAR). Adds tests/unit/scenario/scenarios/airt/test_concurrent_deep_dive.py (13 tests) covering: validation, order preservation, empty short-circuit, peak in-flight observation via asyncio.Event gating, and semaphore upper-bound enforcement. Updates test_sweep_then_deep_dive_input_schema.py for the 5-role schema. Per rlundeen review on #1767: R4 is the concrete concurrent-dispatch payload made possible by R3's active_steps split. Per-atomic active_steps publication (graph.mark_step_running per the plan's example) is a follow-up that requires the StepStatus sidecar abstraction that R3 explicitly deferred. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../scenarios/airt/sweep_then_deep_dive.py | 83 +++- .../airt/test_concurrent_deep_dive.py | 397 ++++++++++++++++++ .../test_sweep_then_deep_dive_input_schema.py | 15 +- 3 files changed, 482 insertions(+), 13 deletions(-) create mode 100644 tests/unit/scenario/scenarios/airt/test_concurrent_deep_dive.py diff --git a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py index a33abca7f..fd5bb5610 100644 --- a/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py +++ b/pyrit/scenario/scenarios/airt/sweep_then_deep_dive.py @@ -33,6 +33,7 @@ from __future__ import annotations +import asyncio import logging from enum import Enum from typing import TYPE_CHECKING, ClassVar, Optional, cast @@ -313,6 +314,7 @@ def __init__( atomic_attacks: Sequence[AtomicAttack], weak_categories_ref: Callable[[], set[str]], max_concurrency: int = 1, + max_step_concurrency: int = 1, ) -> None: """ Initialize the filtered deep-dive step. @@ -327,19 +329,34 @@ def __init__( set by reference) lets the sweep step build its weak-set fresh on each scenario attempt without the deep-dive step holding a stale reference. - max_concurrency (int): Forwarded to each ``AtomicAttack.run_async``. + max_concurrency (int): Forwarded to each ``AtomicAttack.run_async``; + controls the per-atomic objective-level fan-out (the existing + ``max_concurrency`` plumbing — unchanged semantics). + max_step_concurrency (int): R4 — controls how many wrapped atomics + may dispatch concurrently against the model. Default ``1`` + preserves the sequential ``for atomic in self._atomics`` semantics + bit-for-bit. ``>1`` wraps the per-atomic dispatch in an + ``asyncio.Semaphore`` and runs them under ``asyncio.gather``; + ``dispatched_categories`` / ``attack_results`` retain input + order because gather preserves it. Raises: ValueError: If ``atomic_attacks`` is empty. + ValueError: If ``max_step_concurrency`` is less than ``1``. """ if not atomic_attacks: raise ValueError("FilteredDeepDiveStep requires at least one atomic attack.") + if max_step_concurrency < 1: + raise ValueError( + f"max_step_concurrency must be >= 1, got {max_step_concurrency}.", + ) self.name = atomic_attack_name self.outputs = list(self._OUTPUTS) self._atomics = list(atomic_attacks) self._weak_categories_ref = weak_categories_ref self._max_concurrency = max_concurrency + self._max_step_concurrency = max_step_concurrency self.atomic_attack_name = atomic_attack_name self.display_group = atomic_attack_name @@ -371,37 +388,61 @@ async def process_async(self) -> ScenarioStepResult: """ Run each wrapped atomic conditionally on its category being flagged. + With ``max_step_concurrency == 1`` (the default) the atomics dispatch + sequentially in input order — identical to the pre-R4 behavior. + With ``max_step_concurrency > 1`` the eligible atomics are dispatched + concurrently under an ``asyncio.Semaphore`` and awaited via + ``asyncio.gather``; the returned ``dispatched_categories`` and + ``attack_results`` lists retain input order because ``gather`` + preserves it. + Returns: ScenarioStepResult: Outcome is always ``"done"``. The aggregated ``attack_results`` contain results from every atomic that was actually dispatched. ``metadata['skipped_categories']`` lists categories that were not in the weak set; ``metadata['dispatched_categories']`` lists the categories - actually exercised. + actually exercised; ``metadata['max_step_concurrency']`` + records the concurrency cap used for this attempt. """ weak = self._weak_categories_ref() - attack_results: list[AttackResult] = [] dispatched: list[str] = [] skipped: list[str] = [] + eligible: list[tuple[str, AtomicAttack]] = [] for atomic in self._atomics: category = atomic.display_group or atomic.atomic_attack_name if category not in weak: skipped.append(category) continue - executor_result = await atomic.run_async( - max_concurrency=self._max_concurrency, - return_partial_on_failure=True, - ) - attack_results.extend(executor_result.completed_results) + eligible.append((category, atomic)) dispatched.append(category) + attack_results: list[AttackResult] = [] + if eligible: + semaphore = asyncio.Semaphore(self._max_step_concurrency) + + async def _run_one(atomic: AtomicAttack) -> list[AttackResult]: + async with semaphore: + executor_result = await atomic.run_async( + max_concurrency=self._max_concurrency, + return_partial_on_failure=True, + ) + return list(executor_result.completed_results) + + per_atomic_results = await asyncio.gather( + *(_run_one(atomic) for _, atomic in eligible), + ) + for results in per_atomic_results: + attack_results.extend(results) + return ScenarioStepResult( outcome="done", attack_results=attack_results, metadata={ "dispatched_categories": dispatched, "skipped_categories": skipped, + "max_step_concurrency": self._max_step_concurrency, }, ) @@ -464,6 +505,7 @@ def __init__( deep_dive_atomic_attacks: Sequence[AtomicAttack], outcome_scorer: OutcomeScorer, weakness_label: str = "safety_violation", + max_step_concurrency: int = 1, objective_scorer: Optional[TrueFalseScorer] = None, scenario_result_id: Optional[str] = None, ) -> None: @@ -483,6 +525,13 @@ def __init__( weakness_label (str): The label emitted by ``outcome_scorer`` that signals a category breach. Defaults to ``"safety_violation"``. + max_step_concurrency (int): R4 — how many deep-dive atomics may + dispatch concurrently against the model. Default ``1`` + preserves the sequential pre-R4 behavior bit-for-bit. ``>1`` + wraps the deep-dive fan-out in an ``asyncio.Semaphore`` and + awaits via ``asyncio.gather``. Independent of the per-atomic + ``max_concurrency`` plumbing (which controls objective-level + fan-out inside each ``AtomicAttack.run_async``). objective_scorer (TrueFalseScorer | None): Forwarded to the base ``Scenario``. Defaults to ``outcome_scorer.wrapped_scorer`` cast to ``TrueFalseScorer`` so dataset config bootstrap @@ -494,9 +543,14 @@ def __init__( ValueError: If ``deep_dive_atomic_attacks`` is empty. ValueError: If ``weakness_label`` is not declared as one of ``outcome_scorer.outcomes``. + ValueError: If ``max_step_concurrency`` is less than ``1``. """ if not deep_dive_atomic_attacks: raise ValueError("BroadSweepThenDeepDive requires at least one deep_dive_atomic_attack.") + if max_step_concurrency < 1: + raise ValueError( + f"max_step_concurrency must be >= 1, got {max_step_concurrency}.", + ) # Fail fast: the inner ``CategoryAggregatingSweepStep`` performs the # same check, but only inside ``_build_execution_graph`` (called from @@ -515,6 +569,7 @@ def __init__( self._deep_dive_atomics: list[AtomicAttack] = list(deep_dive_atomic_attacks) self._outcome_scorer = outcome_scorer self._weakness_label = weakness_label + self._max_step_concurrency = max_step_concurrency # Shared mutable handle the sweep step updates and the deep-dive step # reads. Reset on each ``run_async`` via ``_build_execution_graph``. @@ -595,6 +650,17 @@ def input_schema(cls) -> list[RoleDescriptor]: default="safety_violation", required=False, ), + RoleDescriptor( + name="max_step_concurrency", + description=( + "R4 — how many deep-dive atomics may dispatch concurrently against the model. " + "Default 1 preserves sequential dispatch; >1 fans out via asyncio.gather under a semaphore." + ), + tag=RoleTag.SCALAR, + param_type=int, + default=1, + required=False, + ), ] @classmethod @@ -697,6 +763,7 @@ def _build_execution_graph( # ty: ignore[invalid-method-override] atomic_attacks=self._deep_dive_atomics, weak_categories_ref=lambda: self._weak_categories, max_concurrency=self._max_concurrency, + max_step_concurrency=self._max_step_concurrency, ) policy = self._build_branching_policy(sweep_step=sweep_step, deep_dive_step=deep_dive_step) diff --git a/tests/unit/scenario/scenarios/airt/test_concurrent_deep_dive.py b/tests/unit/scenario/scenarios/airt/test_concurrent_deep_dive.py new file mode 100644 index 000000000..59aa84e0a --- /dev/null +++ b/tests/unit/scenario/scenarios/airt/test_concurrent_deep_dive.py @@ -0,0 +1,397 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +R4 — concurrent deep-dive dispatch in ``FilteredDeepDiveStep``. + +Pins the contract that adding ``max_step_concurrency > 1`` to +``FilteredDeepDiveStep`` (and the ``BroadSweepThenDeepDive`` scenario that +plumbs it) fans out the per-atomic dispatch through ``asyncio.gather`` under +a semaphore while preserving: + +* input order of ``dispatched_categories`` and ``attack_results``; +* short-circuit on empty weak-set; +* bit-for-bit semantics at the default ``max_step_concurrency=1``; +* fail-fast validation on bogus values. + +The hardest contract to pin is "multiple atomics are genuinely in-flight at +the same time". That is asserted via an ``asyncio.Event``-gated fake +``AtomicAttack.run_async`` that records the peak observed in-flight count +and only releases once the cap is reached. With ``max_step_concurrency=N`` +and ``N`` candidates, all ``N`` must reach the gate before any completes. +""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import MagicMock, PropertyMock + +import pytest + +from pyrit.executor.attack.core import AttackExecutorResult +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import AttackOutcome, AttackResult +from pyrit.scenario.core import AtomicAttack +from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + FilteredDeepDiveStep, +) + + +def _attack_result(*, conversation_id: str, objective: str) -> AttackResult: + result = AttackResult( + conversation_id=conversation_id, + objective=objective, + outcome=AttackOutcome.SUCCESS, + executed_turns=1, + ) + result.atomic_attack_identifier = ComponentIdentifier( + class_name="MockAttack", + class_module="tests.unit.scenario.scenarios.airt", + params={"name": conversation_id}, + ) + return result + + +def _make_atomic_mock( + *, + name: str, + display_group: str, + attack_results: list[AttackResult], +) -> MagicMock: + attack = MagicMock(spec=AtomicAttack) + attack.atomic_attack_name = name + attack.display_group = display_group + type(attack).objectives = PropertyMock(return_value=[r.objective for r in attack_results]) + type(attack).seed_groups = PropertyMock(return_value=[]) + attack.get_identifier.return_value = ComponentIdentifier( + class_name="AtomicAttack", + class_module="tests.unit.scenario.scenarios.airt", + params={"name": name}, + ) + + async def _fake_run(*args: Any, **kwargs: Any) -> AttackExecutorResult: + return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_fake_run) + return attack + + +def _make_gated_atomic_mock( + *, + name: str, + display_group: str, + attack_results: list[AttackResult], + inflight_counter: dict[str, int], + peak: dict[str, int], + release_event: asyncio.Event, +) -> MagicMock: + """Build an AtomicAttack mock whose run_async blocks on ``release_event``. + + Tracks how many invocations are simultaneously past the gate (in-flight) + so the test can assert the peak concurrency reached. + """ + attack = MagicMock(spec=AtomicAttack) + attack.atomic_attack_name = name + attack.display_group = display_group + type(attack).objectives = PropertyMock(return_value=[r.objective for r in attack_results]) + type(attack).seed_groups = PropertyMock(return_value=[]) + attack.get_identifier.return_value = ComponentIdentifier( + class_name="AtomicAttack", + class_module="tests.unit.scenario.scenarios.airt", + params={"name": name}, + ) + + async def _gated_run(*args: Any, **kwargs: Any) -> AttackExecutorResult: + inflight_counter["n"] += 1 + peak["n"] = max(peak["n"], inflight_counter["n"]) + try: + await release_event.wait() + finally: + inflight_counter["n"] -= 1 + return AttackExecutorResult(completed_results=attack_results, incomplete_objectives=[]) + + attack.run_async = MagicMock(side_effect=_gated_run) + return attack + + +@pytest.mark.usefixtures("patch_central_database") +class TestFilteredDeepDiveStepConcurrency: + """R4: deep-dive step fans out under asyncio.Semaphore when concurrency > 1.""" + + def test_init_rejects_zero_max_step_concurrency(self) -> None: + a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + with pytest.raises(ValueError, match=r"max_step_concurrency must be >= 1"): + FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[a], + weak_categories_ref=lambda: {"cat-a"}, + max_step_concurrency=0, + ) + + def test_init_rejects_negative_max_step_concurrency(self) -> None: + a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + with pytest.raises(ValueError, match=r"max_step_concurrency must be >= 1"): + FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[a], + weak_categories_ref=lambda: {"cat-a"}, + max_step_concurrency=-3, + ) + + def test_default_concurrency_is_one(self) -> None: + a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[a], + weak_categories_ref=lambda: {"cat-a"}, + ) + assert step._max_step_concurrency == 1 + + async def test_dispatch_order_preserved_with_concurrency_gt_one(self) -> None: + result_a = _attack_result(conversation_id="da", objective="oa") + result_b = _attack_result(conversation_id="db", objective="ob") + result_c = _attack_result(conversation_id="dc", objective="oc") + atomic_a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[result_a]) + atomic_b = _make_atomic_mock(name="b", display_group="cat-b", attack_results=[result_b]) + atomic_c = _make_atomic_mock(name="c", display_group="cat-c", attack_results=[result_c]) + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[atomic_a, atomic_b, atomic_c], + weak_categories_ref=lambda: {"cat-a", "cat-b", "cat-c"}, + max_step_concurrency=3, + ) + + step_result = await step.process_async() + + assert step_result.metadata["dispatched_categories"] == ["cat-a", "cat-b", "cat-c"] + assert step_result.metadata["skipped_categories"] == [] + assert step_result.attack_results == [result_a, result_b, result_c] + assert step_result.metadata["max_step_concurrency"] == 3 + + async def test_skipped_categories_retain_input_order(self) -> None: + result_b = _attack_result(conversation_id="db", objective="ob") + atomic_a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + atomic_b = _make_atomic_mock(name="b", display_group="cat-b", attack_results=[result_b]) + atomic_c = _make_atomic_mock(name="c", display_group="cat-c", attack_results=[]) + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[atomic_a, atomic_b, atomic_c], + weak_categories_ref=lambda: {"cat-b"}, + max_step_concurrency=2, + ) + + step_result = await step.process_async() + + atomic_a.run_async.assert_not_called() + atomic_b.run_async.assert_called_once() + atomic_c.run_async.assert_not_called() + assert step_result.metadata["dispatched_categories"] == ["cat-b"] + assert step_result.metadata["skipped_categories"] == ["cat-a", "cat-c"] + assert step_result.attack_results == [result_b] + + async def test_empty_weak_set_short_circuits_without_dispatch(self) -> None: + atomic_a = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + atomic_b = _make_atomic_mock(name="b", display_group="cat-b", attack_results=[]) + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=[atomic_a, atomic_b], + weak_categories_ref=lambda: set(), + max_step_concurrency=4, + ) + + step_result = await step.process_async() + + atomic_a.run_async.assert_not_called() + atomic_b.run_async.assert_not_called() + assert step_result.metadata["dispatched_categories"] == [] + assert step_result.metadata["skipped_categories"] == ["cat-a", "cat-b"] + assert step_result.attack_results == [] + assert step_result.metadata["max_step_concurrency"] == 4 + + async def test_concurrent_dispatch_observes_n_simultaneously_inflight(self) -> None: + """All N eligible atomics must reach the gate before any completes when concurrency >= N.""" + inflight: dict[str, int] = {"n": 0} + peak: dict[str, int] = {"n": 0} + release = asyncio.Event() + + atomics = [ + _make_gated_atomic_mock( + name=f"a{i}", + display_group=f"cat-{i}", + attack_results=[_attack_result(conversation_id=f"d{i}", objective=f"o{i}")], + inflight_counter=inflight, + peak=peak, + release_event=release, + ) + for i in range(4) + ] + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=atomics, + weak_categories_ref=lambda: {f"cat-{i}" for i in range(4)}, + max_step_concurrency=4, + ) + + gather_task = asyncio.create_task(step.process_async()) + # Yield until all 4 atomics have reached the gate. + for _ in range(200): + if inflight["n"] >= 4: + break + await asyncio.sleep(0) + assert inflight["n"] == 4, f"expected 4 in-flight, got {inflight['n']}" + release.set() + result = await gather_task + + assert peak["n"] == 4 + assert result.metadata["max_step_concurrency"] == 4 + assert result.metadata["dispatched_categories"] == [f"cat-{i}" for i in range(4)] + + async def test_semaphore_bounds_inflight_below_candidate_count(self) -> None: + """With ``max_step_concurrency=2`` and 4 candidates, peak in-flight must be <= 2.""" + inflight: dict[str, int] = {"n": 0} + peak: dict[str, int] = {"n": 0} + release = asyncio.Event() + + atomics = [ + _make_gated_atomic_mock( + name=f"a{i}", + display_group=f"cat-{i}", + attack_results=[_attack_result(conversation_id=f"d{i}", objective=f"o{i}")], + inflight_counter=inflight, + peak=peak, + release_event=release, + ) + for i in range(4) + ] + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=atomics, + weak_categories_ref=lambda: {f"cat-{i}" for i in range(4)}, + max_step_concurrency=2, + ) + + gather_task = asyncio.create_task(step.process_async()) + # Yield long enough for the semaphore to have admitted 2 — and only 2. + for _ in range(200): + if inflight["n"] >= 2: + break + await asyncio.sleep(0) + # Spin a few more iterations to confirm the cap holds. + for _ in range(50): + await asyncio.sleep(0) + + assert inflight["n"] == 2 + assert peak["n"] == 2 + release.set() + result = await gather_task + # After release, every atomic completed. + assert len(result.attack_results) == 4 + + async def test_default_concurrency_one_is_strictly_sequential(self) -> None: + """With default ``max_step_concurrency=1``, only one atomic may be in-flight at a time.""" + inflight: dict[str, int] = {"n": 0} + peak: dict[str, int] = {"n": 0} + # Use a per-call release pattern: release immediately so the sequential + # case still completes. Peak should never exceed 1. + release = asyncio.Event() + release.set() + + atomics = [ + _make_gated_atomic_mock( + name=f"a{i}", + display_group=f"cat-{i}", + attack_results=[_attack_result(conversation_id=f"d{i}", objective=f"o{i}")], + inflight_counter=inflight, + peak=peak, + release_event=release, + ) + for i in range(3) + ] + + step = FilteredDeepDiveStep( + atomic_attack_name="deep", + atomic_attacks=atomics, + weak_categories_ref=lambda: {f"cat-{i}" for i in range(3)}, + # default max_step_concurrency=1 + ) + + result = await step.process_async() + assert peak["n"] == 1 + assert result.metadata["max_step_concurrency"] == 1 + assert len(result.attack_results) == 3 + + +@pytest.mark.usefixtures("patch_central_database") +class TestBroadSweepThenDeepDiveConcurrencyPlumbing: + """The scenario constructor must validate + forward ``max_step_concurrency``.""" + + def test_scenario_rejects_zero_max_step_concurrency(self) -> None: + from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + ) + + atomic = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + scorer = MagicMock() + scorer.outcomes = {"safety_violation", "safe"} + + with pytest.raises(ValueError, match=r"max_step_concurrency must be >= 1"): + BroadSweepThenDeepDive( + sweep_atomic_attack=atomic, + deep_dive_atomic_attacks=[atomic], + outcome_scorer=scorer, + max_step_concurrency=0, + ) + + def test_scenario_rejects_negative_max_step_concurrency(self) -> None: + from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + ) + + atomic = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + scorer = MagicMock() + scorer.outcomes = {"safety_violation", "safe"} + + with pytest.raises(ValueError, match=r"max_step_concurrency must be >= 1"): + BroadSweepThenDeepDive( + sweep_atomic_attack=atomic, + deep_dive_atomic_attacks=[atomic], + outcome_scorer=scorer, + max_step_concurrency=-1, + ) + + def test_scenario_default_max_step_concurrency_is_one(self) -> None: + from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + ) + + atomic = _make_atomic_mock(name="a", display_group="cat-a", attack_results=[]) + scorer = MagicMock() + scorer.outcomes = {"safety_violation", "safe"} + + scenario = BroadSweepThenDeepDive( + sweep_atomic_attack=atomic, + deep_dive_atomic_attacks=[atomic], + outcome_scorer=scorer, + ) + assert scenario._max_step_concurrency == 1 + + def test_input_schema_advertises_max_step_concurrency_scalar(self) -> None: + from pyrit.scenario.core.input_schema import RoleTag + from pyrit.scenario.scenarios.airt.sweep_then_deep_dive import ( + BroadSweepThenDeepDive, + ) + + schema = BroadSweepThenDeepDive.input_schema() + names = [r.name for r in schema] + assert "max_step_concurrency" in names + role = next(r for r in schema if r.name == "max_step_concurrency") + assert role.tag is RoleTag.SCALAR + assert role.param_type is int + assert role.default == 1 + assert role.required is False diff --git a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py index 39da38ebd..b77bf1cf6 100644 --- a/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py +++ b/tests/unit/scenario/scenarios/airt/test_sweep_then_deep_dive_input_schema.py @@ -14,11 +14,15 @@ class TestBroadSweepThenDeepDiveInputSchema: - """``BroadSweepThenDeepDive.input_schema()`` declares 3 OPAQUE roles + 1 SCALAR.""" + """``BroadSweepThenDeepDive.input_schema()`` declares 3 OPAQUE roles + 2 SCALAR. - def test_returns_four_roles(self): + R4 added ``max_step_concurrency`` as a second SCALAR role so the wizard + can expose the deep-dive concurrency knob. + """ + + def test_returns_five_roles(self): schema = BroadSweepThenDeepDive.input_schema() - assert len(schema) == 4 + assert len(schema) == 5 assert all(isinstance(role, RoleDescriptor) for role in schema) def test_role_names_match_constructor_inputs(self): @@ -28,16 +32,17 @@ def test_role_names_match_constructor_inputs(self): "deep_dive_atomic_attacks", "outcome_scorer", "weakness_label", + "max_step_concurrency", ] - def test_three_opaque_one_scalar(self): + def test_three_opaque_two_scalar(self): by_tag: dict[RoleTag, list[str]] = {tag: [] for tag in RoleTag} for role in BroadSweepThenDeepDive.input_schema(): by_tag[role.tag].append(role.name) assert sorted(by_tag[RoleTag.OPAQUE]) == sorted( ["sweep_atomic_attack", "deep_dive_atomic_attacks", "outcome_scorer"] ) - assert by_tag[RoleTag.SCALAR] == ["weakness_label"] + assert sorted(by_tag[RoleTag.SCALAR]) == sorted(["weakness_label", "max_step_concurrency"]) # No other tags present. for tag, names in by_tag.items(): if tag not in {RoleTag.OPAQUE, RoleTag.SCALAR}: