diff --git a/doc/bibliography.md b/doc/bibliography.md index ac3dbf8b9..b2c32e7e0 100644 --- a/doc/bibliography.md +++ b/doc/bibliography.md @@ -5,6 +5,6 @@ All academic papers, research blogs, and technical reports referenced throughout :::{dropdown} Citation Keys :class: hidden-citations -[@aakanksha2024multilingual; @adversaai2023universal; @andriushchenko2024tense; @anthropic2024manyshot; @aqrawi2024singleturncrescendo; @bethany2024mathprompt; @bhardwaj2023harmfulqa; @bhardwaj2024homer; @bryan2025agentictaxonomy; @bullwinkel2025airtlessons; @bullwinkel2025repeng; @bullwinkel2026trigger; @chao2023pair; @chao2024jailbreakbench; @cui2024orbench; @darkbench2025; @derczynski2024garak; @ding2023wolf; @embracethered2024unicode; @embracethered2025sneakybits; @ghosh2025aegis; @gupta2024walledeval; @haider2024phi3safety; @han2024medsafetybench; @hines2024spotlighting; @ji2023beavertails; @ji2024pkusaferlhf; @jiang2025sosbench; @jones2025computeruse; @kingma2014adam; @li2024saladbench; @li2024wmdp; @lin2023toxicchat; @liu2024flipattack; @lopez2024pyrit; @lv2024codechameleon; @mazeika2023tdc; @mazeika2024harmbench; @mckee2024transparency; @mehrotra2023tap; @microsoft2024skeletonkey; @palaskar2025vlsu; @pfohl2024equitymedqa; @promptfoo2025ccp; @robustintelligence2024bypass; @roccia2024promptintel; @rottger2023xstest; @rottger2025msts; @russinovich2024crescendo; @russinovich2025price; @scheuerman2025transphobia; @shaikh2022second; @shayegani2025computeruse; @shen2023donotanything; @sheshadri2024lat; @stok2023ansi; @tan2026comicjailbreak; @tang2025multilingual; @tedeschi2024alert; @vantaylor2024socialbias; @vidgen2023simplesafetytests; @vidgen2024ailuminate; @wang2023decodingtrust; @wang2023donotanswer; @wei2023jailbroken; @xie2024sorrybench; @yu2023gptfuzzer; @yuan2023cipherchat; @zeng2024persuasion; @zhang2024cbtbench; @zou2023gcg] +[@aakanksha2024multilingual; @adversaai2023universal; @andriushchenko2024tense; @anthropic2024manyshot; @aqrawi2024singleturncrescendo; 
@bethany2024mathprompt; @bhardwaj2023harmfulqa; @bhardwaj2024homer; @bryan2025agentictaxonomy; @bullwinkel2025airtlessons; @bullwinkel2025repeng; @bullwinkel2026trigger; @chao2023pair; @chao2024jailbreakbench; @cui2024orbench; @darkbench2025; @derczynski2024garak; @ding2023wolf; @embracethered2024unicode; @embracethered2025sneakybits; @ghosh2025aegis; @gupta2024walledeval; @haider2024phi3safety; @han2024medsafetybench; @hines2024spotlighting; @ji2023beavertails; @ji2024pkusaferlhf; @jiang2025sosbench; @jones2025computeruse; @kingma2014adam; @li2024mossbench; @li2024saladbench; @li2024wmdp; @lin2023toxicchat; @liu2024flipattack; @lopez2024pyrit; @lv2024codechameleon; @mazeika2023tdc; @mazeika2024harmbench; @mckee2024transparency; @mehrotra2023tap; @microsoft2024skeletonkey; @palaskar2025vlsu; @pfohl2024equitymedqa; @promptfoo2025ccp; @robustintelligence2024bypass; @roccia2024promptintel; @rottger2023xstest; @rottger2025msts; @russinovich2024crescendo; @russinovich2025price; @scheuerman2025transphobia; @shaikh2022second; @shayegani2025computeruse; @shen2023donotanything; @sheshadri2024lat; @stok2023ansi; @tan2026comicjailbreak; @tang2025multilingual; @tedeschi2024alert; @vantaylor2024socialbias; @vidgen2023simplesafetytests; @vidgen2024ailuminate; @wang2023decodingtrust; @wang2023donotanswer; @wei2023jailbroken; @xie2024sorrybench; @yu2023gptfuzzer; @yuan2023cipherchat; @zeng2024persuasion; @zhang2024cbtbench; @zou2023gcg] ::: diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index 4dc1c6ab9..49be803d9 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -28,6 +28,7 @@ "JailbreakBench [@chao2024jailbreakbench],\n", "LLM-LAT [@sheshadri2024lat],\n", "MedSafetyBench [@han2024medsafetybench],\n", + "MOSSBench [@li2024mossbench],\n", "Multilingual Alignment Prism [@aakanksha2024multilingual],\n", "Multilingual Vulnerabilities [@tang2025multilingual],\n", "OR-Bench 
[@cui2024orbench],\n", @@ -248,6 +249,9 @@ } ], "metadata": { + "jupytext": { + "main_language": "python" + }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/doc/code/datasets/1_loading_datasets.py b/doc/code/datasets/1_loading_datasets.py index cec5261bb..9bcc8e705 100644 --- a/doc/code/datasets/1_loading_datasets.py +++ b/doc/code/datasets/1_loading_datasets.py @@ -32,6 +32,7 @@ # JailbreakBench [@chao2024jailbreakbench], # LLM-LAT [@sheshadri2024lat], # MedSafetyBench [@han2024medsafetybench], +# MOSSBench [@li2024mossbench], # Multilingual Alignment Prism [@aakanksha2024multilingual], # Multilingual Vulnerabilities [@tang2025multilingual], # OR-Bench [@cui2024orbench], diff --git a/doc/references.bib b/doc/references.bib index d94d0c842..7338f21f3 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -610,3 +610,11 @@ @misc{embracethered2025sneakybits url = {https://embracethered.com/blog/posts/2025/sneaky-bits-and-ascii-smuggler/}, note = {Embrace The Red Blog}, } + +@article{li2024mossbench, + title = {{MOSSBench}: Is Your Multimodal Language Model Oversensitive to Safe Queries?}, + author = {Xirui Li and Hengguang Zhou and Ruochen Wang and Tianyi Zhou and Minhao Cheng and Cho-Jui Hsieh}, + journal = {arXiv preprint arXiv:2406.17806}, + year = {2024}, + url = {https://arxiv.org/abs/2406.17806}, +} diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py index 284afac5a..ba9f3ccdd 100644 --- a/pyrit/datasets/seed_datasets/remote/__init__.py +++ b/pyrit/datasets/seed_datasets/remote/__init__.py @@ -73,6 +73,10 @@ from pyrit.datasets.seed_datasets.remote.mlcommons_ailuminate_dataset import ( _MLCommonsAILuminateDataset, ) # noqa: F401 +from pyrit.datasets.seed_datasets.remote.mossbench_dataset import ( + MossBenchOversensitivityType, + _MossBenchDataset, +) # noqa: F401 from pyrit.datasets.seed_datasets.remote.msts_dataset import ( _MSTSDataset, ) # noqa: F401 @@ -143,6 +147,7 
@@ __all__ = [ "HiXSTestLanguage", + "MossBenchOversensitivityType", "PromptIntelCategory", "PromptIntelSeverity", "SGXSTestLabel", @@ -172,6 +177,7 @@ "_LLMLatentAdversarialTrainingDataset", "_MedSafetyBenchDataset", "_MLCommonsAILuminateDataset", + "_MossBenchDataset", "_MSTSDataset", "_MultilingualVulnerabilityDataset", "_ORBench80KDataset", diff --git a/pyrit/datasets/seed_datasets/remote/mossbench_dataset.py b/pyrit/datasets/seed_datasets/remote/mossbench_dataset.py new file mode 100644 index 000000000..4b956b29c --- /dev/null +++ b/pyrit/datasets/seed_datasets/remote/mossbench_dataset.py @@ -0,0 +1,441 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import uuid +from collections.abc import Iterable +from enum import Enum +from typing import Any, Literal, Optional, cast + +from pyrit.datasets.seed_datasets.remote._image_cache import ( + fetch_and_cache_image_async, +) +from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( + _RemoteDatasetLoader, +) +from pyrit.models import SeedDataset, SeedPrompt + +logger = logging.getLogger(__name__) + + +class MossBenchOversensitivityType(Enum): + """ + Oversensitivity stimulus types in the MOSSBench dataset. + + The MOSSBench paper organizes benign-but-tricky multimodal queries around + three cognitive-science triggers that lead VLMs to refuse harmless requests: + + - ``EXAGGERATED_RISK`` — innocuous scene with a visually salient but + contextually harmless risky-looking element (e.g. a toy knife on a + playroom shelf). + - ``NEGATED_HARM`` — image depicts harm in a way that is explicitly negated + by surrounding context. + - ``COUNTERINTUITIVE_INTERPRETATION`` — image whose obvious literal reading + conflicts with the actual user intent. + + The raw ``metadata.over`` field in the upstream JSON uses the strings + ``"type 1"`` / ``"type 2"`` / ``"type 3"`` for these three concepts; + the loader maps those to the enum members. 
+ """ + + EXAGGERATED_RISK = "exaggerated_risk" + NEGATED_HARM = "negated_harm" + COUNTERINTUITIVE_INTERPRETATION = "counterintuitive_interpretation" + + +# Mapping from the raw ``metadata.over`` string in MOSSBench's information.json +# to the loader's public enum. +_RAW_OVERSENSITIVITY_TO_ENUM: dict[str, MossBenchOversensitivityType] = { + "type 1": MossBenchOversensitivityType.EXAGGERATED_RISK, + "type 2": MossBenchOversensitivityType.NEGATED_HARM, + "type 3": MossBenchOversensitivityType.COUNTERINTUITIVE_INTERPRETATION, +} + +# Human-readable label per oversensitivity type (matches the paper). +_OVERSENSITIVITY_TYPE_LABELS: dict[MossBenchOversensitivityType, str] = { + MossBenchOversensitivityType.EXAGGERATED_RISK: "Exaggerated Risk", + MossBenchOversensitivityType.NEGATED_HARM: "Negated Harm", + MossBenchOversensitivityType.COUNTERINTUITIVE_INTERPRETATION: "Counterintuitive Interpretation", +} + + +class _MossBenchDataset(_RemoteDatasetLoader): + """ + Loader for the MOSSBench multimodal over-sensitivity dataset. + + MOSSBench probes whether vision-language models incorrectly **refuse** + harmless multimodal queries when the image contains a superficially + risky-looking element. Unlike most multimodal safety benchmarks, every one + of the 300 manually curated queries is **benign by design**: a well-behaved + VLM should answer normally, not refuse. + + Each example is delivered as a two-piece multimodal user turn — an image + and a benign text question — both wrapped in ``SeedPrompt`` objects + that share a ``prompt_group_id`` and ``sequence=0``. No ``SeedObjective`` + is created (matching the convention used by the text-only over-refusal + loaders ``_XSTestDataset`` and ``_ORBenchBaseDataset``); the + "non-refusal expected" semantics live in the dataset's identity rather + than in a per-row objective field. + + Notes: + - **Image source**: GitHub raw URLs pinned to commit + ``8d68b0614b39d8990a508e03d99975832f399db2``. 
All 300 PNG images are + in the repo at predictable paths. The first call downloads each image + once via ``fetch_and_cache_image_async``; subsequent calls reuse + the on-disk cache. + - **Harm indices**: the upstream ``metadata.harm`` list contains + HarmBench-style integer indices, but the explicit + ``index → category-label`` mapping is **not** documented in the + paper, repo, or project page. To avoid mislabeling, this loader + preserves the raw integers in ``SeedPrompt.metadata['harm_indices']`` + and uses the over-sensitivity-type slug as the ``harm_categories`` + entry. + - **Image-attribute filters** (``human`` / ``child`` / ``syn`` / ``ocr`` + flags) are surfaced as per-seed metadata only; the loader does not + expose them as constructor filters to keep the API small. + + References: + - Paper: [@li2024mossbench] + - GitHub: https://github.com/xirui-li/MOSSBench + - HF mirror: + + License: CC BY-SA 4.0. + """ + + _COMMIT: str = "8d68b0614b39d8990a508e03d99975832f399db2" + METADATA_URL: str = ( + f"https://raw.githubusercontent.com/xirui-li/MOSSBench/{_COMMIT}/data/images_information/information.json" + ) + IMAGE_BASE_URL: str = f"https://raw.githubusercontent.com/xirui-li/MOSSBench/{_COMMIT}/data/images/" + PAPER_URL: str = "https://arxiv.org/abs/2406.17806" + DESCRIPTION: str = ( + "MOSSBench is a multimodal over-sensitivity benchmark of 300 manually curated benign image+question " + "pairs across three oversensitivity stimulus types (Exaggerated Risk, Negated Harm, Counterintuitive " + "Interpretation). A well-behaved vision-language model should answer each query normally; refusing " + "indicates over-sensitivity. The prompts are benign — the harm-category indices in metadata describe " + "what the image superficially evokes, not the actual harm of the question." + ) + AUTHORS: tuple[str, ...] 
= ( + "Xirui Li", + "Hengguang Zhou", + "Ruochen Wang", + "Tianyi Zhou", + "Minhao Cheng", + "Cho-Jui Hsieh", + ) + + tags: frozenset[str] = frozenset({"default", "safety", "multimodal", "over_refusal"}) + size: str = "medium" + modalities: tuple[str, ...] = ("text", "image") + harm_categories: tuple[str, ...] = tuple(t.value for t in MossBenchOversensitivityType) + + def __init__( + self, + *, + source: str = METADATA_URL, + source_type: Literal["public_url", "file"] = "public_url", + oversensitivity_types: Optional[list[MossBenchOversensitivityType]] = None, + ) -> None: + """ + Initialize the MOSSBench dataset loader. + + Args: + source (str): URL or file path to the MOSSBench ``information.json`` + metadata file. Defaults to the official GitHub repository at a + pinned commit. + source_type (Literal["public_url", "file"]): The type of source + (``"public_url"`` or ``"file"``). + oversensitivity_types (list[MossBenchOversensitivityType] | None): + Filter examples by oversensitivity stimulus type. If ``None`` + (default), all three types are included. Valid values: + ``MossBenchOversensitivityType.EXAGGERATED_RISK``, + ``MossBenchOversensitivityType.NEGATED_HARM``, + ``MossBenchOversensitivityType.COUNTERINTUITIVE_INTERPRETATION``. + + Raises: + ValueError: If any value in ``oversensitivity_types`` is not a + ``MossBenchOversensitivityType`` member. + """ + self.source = source + self.source_type: Literal["public_url", "file"] = source_type + self.oversensitivity_types = oversensitivity_types + + if oversensitivity_types is not None: + self._validate_enums( + oversensitivity_types, + MossBenchOversensitivityType, + "oversensitivity type", + ) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "mossbench" + + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: + """ + Fetch MOSSBench examples and return them as a ``SeedDataset``. 
+ + Each example yields two ``SeedPrompt`` objects — an image and a + benign text question — that share a ``prompt_group_id`` and + ``sequence=0`` so the orchestrator delivers them as a single + multimodal user turn. + + Args: + cache (bool): Whether to cache the fetched metadata. Defaults to + ``True``. + + Returns: + SeedDataset: A ``SeedDataset`` containing the multimodal + examples. + + Raises: + ValueError: If the JSON is not a dict, or an example lacks required keys or a known ``over`` type. + """ + logger.info(f"Loading MOSSBench dataset from {self.source}") + + examples = self._load_examples(cache=cache) + prompts: list[SeedPrompt] = [] + failed_image_count = 0 + + for example in examples: + pid, question, image_filename = self._extract_required_fields(example) + oversensitivity_type = self._parse_oversensitivity_type(example) + if not self._matches_filters(oversensitivity_type): + continue + + try: + pair = await self._build_prompt_pair_async( + pid=pid, + question=question, + image_filename=image_filename, + example=example, + oversensitivity_type=oversensitivity_type, + ) + except Exception as e: + failed_image_count += 1 + logger.warning(f"[MOSSBench] Failed to fetch image for pid={pid}: {e}. Skipping this example.") + continue + + prompts.extend(pair) + + if failed_image_count > 0: + logger.warning(f"[MOSSBench] Skipped {failed_image_count} example(s) due to image fetch failures") + + logger.info(f"Successfully loaded {len(prompts)} prompts from MOSSBench dataset") + return SeedDataset(seeds=prompts, dataset_name=self.dataset_name) + + def _load_examples(self, *, cache: bool) -> Iterable[dict[str, Any]]: + """ + Fetch the raw MOSSBench ``information.json`` and yield example dicts. + + The upstream JSON is a dict keyed by ``pid`` (1..300); this helper + normalizes the structure to an iterable of value dicts so the rest of + the loader can be source-agnostic. + + Args: + cache (bool): Whether to cache the fetched metadata locally. 
+ + Returns: + Iterable[dict[str, Any]]: Iterable over per-pid example dicts. + + Raises: + ValueError: If the parsed JSON is not a dict of pid → entry. + """ + raw = cast( + "Any", + self._fetch_from_url(source=self.source, source_type=self.source_type, cache=cache), + ) + if not isinstance(raw, dict): + raise ValueError( + f"Expected MOSSBench information.json to be a dict keyed by pid, got {type(raw).__name__}." + ) + return cast("Iterable[dict[str, Any]]", raw.values()) + + def _matches_filters(self, oversensitivity_type: MossBenchOversensitivityType) -> bool: + """ + Return whether an example passes the configured oversensitivity-type filter. + + Args: + oversensitivity_type (MossBenchOversensitivityType): Parsed oversensitivity + type for the candidate example. + + Returns: + bool: ``True`` if the example should be included. + """ + if self.oversensitivity_types is None: + return True + return oversensitivity_type in self.oversensitivity_types + + async def _build_prompt_pair_async( + self, + *, + pid: str, + question: str, + image_filename: str, + example: dict[str, Any], + oversensitivity_type: MossBenchOversensitivityType, + ) -> list[SeedPrompt]: + """ + Build an image+text ``SeedPrompt`` pair for a single MOSSBench example. + + Args: + pid (str): MOSSBench prompt id (used in seed names and cached image + filename). + question (str): Benign text question for the example. + image_filename (str): Basename of the upstream image (e.g. + ``"42.png"``); joined onto ``IMAGE_BASE_URL`` to form the URL. + example (dict[str, Any]): Single example dict from the upstream + ``information.json`` (used to extract attribute-flag metadata). + oversensitivity_type (MossBenchOversensitivityType): Parsed + oversensitivity type for the example. + + Returns: + list[SeedPrompt]: A two-element list — the image prompt followed by + the text prompt — both sharing ``prompt_group_id`` and + ``sequence=0``. + + Raises: + Exception: If the image cannot be fetched. 
+ """ + meta = self._extract_metadata(example=example, oversensitivity_type=oversensitivity_type) + group_id = uuid.uuid4() + image_url = f"{self.IMAGE_BASE_URL}{image_filename}" + + local_image_path = await self._fetch_and_save_image_async(image_url=image_url, pid=pid) + + image_prompt = SeedPrompt( + value=local_image_path, + data_type="image_path", + name=f"MOSSBench Image - {pid}", + dataset_name=self.dataset_name, + harm_categories=[oversensitivity_type.value], + description=self.DESCRIPTION, + authors=list(self.AUTHORS), + source=self.PAPER_URL, + prompt_group_id=group_id, + sequence=0, + metadata={**meta, "original_image_url": image_url}, + ) + + text_prompt = SeedPrompt( + value=question, + data_type="text", + name=f"MOSSBench Text - {pid}", + dataset_name=self.dataset_name, + harm_categories=[oversensitivity_type.value], + description=self.DESCRIPTION, + authors=list(self.AUTHORS), + source=self.PAPER_URL, + prompt_group_id=group_id, + sequence=0, + metadata=meta, + ) + + return [image_prompt, text_prompt] + + @staticmethod + def _extract_required_fields(example: dict[str, Any]) -> tuple[str, str, str]: + """ + Pull ``pid``, ``question``, and ``image`` filename from a raw example. + + Args: + example (dict[str, Any]): Single example dict from the upstream + ``information.json``. + + Returns: + tuple[str, str, str]: ``(pid, question, image_filename)`` where + ``image_filename`` is the basename (e.g. ``"42.png"``). + + Raises: + ValueError: If any of the required keys is missing. + """ + required = {"pid", "question", "image"} + missing = required - example.keys() + if missing: + raise ValueError(f"Missing keys in MOSSBench example: {', '.join(sorted(missing))}") + + pid = str(example["pid"]) + question = str(example["question"]) + # ``image`` looks like "images/42.png"; we only need the basename so we + # can join it onto IMAGE_BASE_URL ourselves. 
+ image_filename = str(example["image"]).rsplit("/", 1)[-1] + return pid, question, image_filename + + @staticmethod + def _parse_oversensitivity_type(example: dict[str, Any]) -> MossBenchOversensitivityType: + """ + Map the raw ``metadata.over`` string to a ``MossBenchOversensitivityType``. + + Args: + example (dict[str, Any]): Single example dict from the upstream + ``information.json``. + + Returns: + MossBenchOversensitivityType: The parsed oversensitivity type. + + Raises: + ValueError: If ``metadata.over`` is missing or not one of the three + known values. + """ + meta = example.get("metadata") or {} + raw_over = meta.get("over") + if raw_over not in _RAW_OVERSENSITIVITY_TO_ENUM: + valid = ", ".join(sorted(_RAW_OVERSENSITIVITY_TO_ENUM)) + raise ValueError( + f"MOSSBench example pid={example.get('pid', '?')} has unknown over type " + f"{raw_over!r}; expected one of: {valid}." + ) + return _RAW_OVERSENSITIVITY_TO_ENUM[raw_over] + + @staticmethod + def _extract_metadata( + *, + example: dict[str, Any], + oversensitivity_type: MossBenchOversensitivityType, + ) -> dict[str, Any]: + """ + Build the per-seed metadata dict, preserving all upstream attribute flags. + + Args: + example (dict[str, Any]): Single example dict from the upstream + ``information.json``. + oversensitivity_type (MossBenchOversensitivityType): Parsed + oversensitivity type for the example. + + Returns: + dict[str, Any]: Metadata dict including ``pid``, the + oversensitivity-type slug + label, and the raw image-attribute flags + (``human``, ``child``, ``syn``, ``ocr``, ``harm_indices``). 
+ """ + meta = example.get("metadata") or {} + harm_raw = meta.get("harm") or [] + harm_indices = [int(h) for h in harm_raw if isinstance(h, (int, float, str)) and str(h).lstrip("-").isdigit()] + return { + "pid": str(example["pid"]), + "oversensitivity_type": oversensitivity_type.value, + "oversensitivity_type_label": _OVERSENSITIVITY_TYPE_LABELS[oversensitivity_type], + "human": bool(meta.get("human", 0)), + "child": bool(meta.get("child", 0)), + "syn": bool(meta.get("syn", 0)), + "ocr": bool(meta.get("ocr", 0)), + "harm_indices": harm_indices, + "short_description": str(example.get("short description", "")), + } + + async def _fetch_and_save_image_async(self, *, image_url: str, pid: str) -> str: + """ + Fetch and cache a MOSSBench image. + + Args: + image_url (str): URL to the image PNG. + pid (str): MOSSBench prompt id, used to name the cached file. + + Returns: + str: Local path to the cached image. + """ + return await fetch_and_cache_image_async( + filename=f"mossbench_{pid}.png", + image_url=image_url, + log_prefix="MOSSBench", + ) diff --git a/tests/unit/datasets/test_mossbench_dataset.py b/tests/unit/datasets/test_mossbench_dataset.py new file mode 100644 index 000000000..585d41ab2 --- /dev/null +++ b/tests/unit/datasets/test_mossbench_dataset.py @@ -0,0 +1,305 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from typing import Any +from unittest.mock import patch + +import pytest + +from pyrit.datasets.seed_datasets.remote.mossbench_dataset import ( + MossBenchOversensitivityType, + _MossBenchDataset, +) +from pyrit.memory import SQLiteMemory +from pyrit.memory.central_memory import CentralMemory +from pyrit.models import SeedDataset + + +def _make_example( + *, + pid: int, + over: str, + question: str = "Describe a fun game a child can play with these toys.", + human: int = 0, + child: int = 1, + syn: int = 1, + ocr: int = 0, + harm: list[int] | None = None, +) -> dict[str, Any]: + return { + "image": f"images/{pid}.png", + "short description": f"Short description {pid}.", + "question": question, + "metadata": { + "over": over, + "human": human, + "child": child, + "syn": syn, + "ocr": ocr, + "harm": harm if harm is not None else [7], + }, + "pid": str(pid), + "description": f"Long description {pid}.", + } + + +def _make_information_json(examples: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + """Mirror the upstream pid-keyed dict format produced by ``information.json``.""" + return {ex["pid"]: ex for ex in examples} + + +class TestMossBenchDataset: + """Unit tests for ``_MossBenchDataset``.""" + + @pytest.fixture(autouse=True) + def setup_memory(self): + """Set up memory instance for image downloads.""" + memory = SQLiteMemory() + CentralMemory.set_memory_instance(memory) + yield + CentralMemory.set_memory_instance(None) + + def test_dataset_name(self): + dataset = _MossBenchDataset() + assert dataset.dataset_name == "mossbench" + + def test_init_defaults(self): + dataset = _MossBenchDataset() + assert dataset.oversensitivity_types is None + assert dataset.source_type == "public_url" + assert "/MOSSBench/" in dataset.source + assert "information.json" in dataset.source + + def test_init_with_oversensitivity_types(self): + types = [ + MossBenchOversensitivityType.EXAGGERATED_RISK, + MossBenchOversensitivityType.NEGATED_HARM, + ] + dataset = 
_MossBenchDataset(oversensitivity_types=types) + assert dataset.oversensitivity_types == types + + def test_init_with_invalid_oversensitivity_types_raises(self): + with pytest.raises(ValueError, match="Expected MossBenchOversensitivityType"): + _MossBenchDataset(oversensitivity_types=["exaggerated_risk"]) # type: ignore[list-item] + + def test_init_rejects_raw_string_matching_enum_value(self): + with pytest.raises(ValueError, match="Expected MossBenchOversensitivityType"): + _MossBenchDataset(oversensitivity_types=["type 1"]) # type: ignore[list-item] + + def test_dataset_level_metadata(self): + dataset = _MossBenchDataset() + assert "over_refusal" in dataset.tags + assert "multimodal" in dataset.tags + assert "safety" in dataset.tags + assert dataset.size == "medium" + assert set(dataset.modalities) == {"text", "image"} + assert set(dataset.harm_categories) == { + "exaggerated_risk", + "negated_harm", + "counterintuitive_interpretation", + } + + async def test_fetch_dataset_happy_path(self): + examples = [ + _make_example(pid=1, over="type 1", question="Q1", harm=[7]), + _make_example(pid=2, over="type 2", question="Q2", harm=[5, 6]), + _make_example(pid=3, over="type 3", question="Q3", harm=[1]), + ] + mock_data = _make_information_json(examples) + dataset_loader = _MossBenchDataset() + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + side_effect=lambda *, image_url, pid: f"/fake/path/{pid}.png", + ), + ): + dataset = await dataset_loader.fetch_dataset_async(cache=False) + + assert isinstance(dataset, SeedDataset) + assert len(dataset.seeds) == 6 # 3 examples * 2 seeds each + + text_prompts = [s for s in dataset.seeds if s.data_type == "text"] + image_prompts = [s for s in dataset.seeds if s.data_type == "image_path"] + assert len(text_prompts) == 3 + assert len(image_prompts) == 3 + + # Verify image+text pairs share group_id and sequence=0 + text_by_pid = 
{p.metadata["pid"]: p for p in text_prompts} + image_by_pid = {p.metadata["pid"]: p for p in image_prompts} + for pid in ("1", "2", "3"): + text_prompt = text_by_pid[pid] + image_prompt = image_by_pid[pid] + assert text_prompt.prompt_group_id == image_prompt.prompt_group_id + assert text_prompt.sequence == 0 + assert image_prompt.sequence == 0 + + # Different examples have different group_ids + group_ids = {p.prompt_group_id for p in dataset.seeds} + assert len(group_ids) == 3 + + async def test_fetch_dataset_metadata_and_objective_semantics(self): + examples = [_make_example(pid=1, over="type 1", question="Q1", harm=[7, 6])] + mock_data = _make_information_json(examples) + dataset_loader = _MossBenchDataset() + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + return_value="/fake/path/1.png", + ), + ): + dataset = await dataset_loader.fetch_dataset_async(cache=False) + + text_prompt = next(s for s in dataset.seeds if s.data_type == "text") + image_prompt = next(s for s in dataset.seeds if s.data_type == "image_path") + + # Text prompt holds the literal benign question (mirrors XSTest/OR-Bench). + assert text_prompt.value == "Q1" + assert image_prompt.value == "/fake/path/1.png" + assert text_prompt.dataset_name == "mossbench" + assert image_prompt.dataset_name == "mossbench" + assert text_prompt.harm_categories == ["exaggerated_risk"] + assert image_prompt.harm_categories == ["exaggerated_risk"] + + # Metadata preserves over-type, attribute flags, harm indices, pid. 
+ for prompt in (text_prompt, image_prompt): + assert prompt.metadata["pid"] == "1" + assert prompt.metadata["oversensitivity_type"] == "exaggerated_risk" + assert prompt.metadata["oversensitivity_type_label"] == "Exaggerated Risk" + assert prompt.metadata["human"] is False + assert prompt.metadata["child"] is True + assert prompt.metadata["syn"] is True + assert prompt.metadata["ocr"] is False + assert prompt.metadata["harm_indices"] == [7, 6] + assert "Short description 1." in prompt.metadata["short_description"] + + # Image prompt also carries the original URL for traceability. + assert image_prompt.metadata["original_image_url"].endswith("/data/images/1.png") + assert "original_image_url" not in text_prompt.metadata + + async def test_fetch_dataset_filters_by_type(self): + examples = [ + _make_example(pid=1, over="type 1"), + _make_example(pid=2, over="type 2"), + _make_example(pid=3, over="type 3"), + ] + mock_data = _make_information_json(examples) + dataset_loader = _MossBenchDataset(oversensitivity_types=[MossBenchOversensitivityType.NEGATED_HARM]) + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + return_value="/fake/path/2.png", + ), + ): + dataset = await dataset_loader.fetch_dataset_async(cache=False) + + assert len(dataset.seeds) == 2 # one pair only + pids = {s.metadata["pid"] for s in dataset.seeds} + assert pids == {"2"} + for seed in dataset.seeds: + assert seed.harm_categories == ["negated_harm"] + + async def test_fetch_dataset_filters_by_multiple_types(self): + examples = [ + _make_example(pid=1, over="type 1"), + _make_example(pid=2, over="type 2"), + _make_example(pid=3, over="type 3"), + ] + mock_data = _make_information_json(examples) + dataset_loader = _MossBenchDataset( + oversensitivity_types=[ + MossBenchOversensitivityType.EXAGGERATED_RISK, + MossBenchOversensitivityType.COUNTERINTUITIVE_INTERPRETATION, + ] + ) + + with ( + 
patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + side_effect=lambda *, image_url, pid: f"/fake/path/{pid}.png", + ), + ): + dataset = await dataset_loader.fetch_dataset_async(cache=False) + + pids = {s.metadata["pid"] for s in dataset.seeds} + assert pids == {"1", "3"} + + async def test_fetch_dataset_missing_required_key_raises(self): + bad_example = { + "image": "images/1.png", + "metadata": {"over": "type 1"}, + "pid": "1", + # missing "question" + } + mock_data = {"1": bad_example} + dataset_loader = _MossBenchDataset() + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + return_value="/fake/path/1.png", + ), + pytest.raises(ValueError, match="Missing keys in MOSSBench example"), + ): + await dataset_loader.fetch_dataset_async(cache=False) + + async def test_fetch_dataset_unknown_oversensitivity_type_raises(self): + bad_example = _make_example(pid=1, over="type 99") + mock_data = _make_information_json([bad_example]) + dataset_loader = _MossBenchDataset() + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + return_value="/fake/path/1.png", + ), + pytest.raises(ValueError, match="unknown over type"), + ): + await dataset_loader.fetch_dataset_async(cache=False) + + async def test_fetch_dataset_skips_failed_image_download(self): + examples = [ + _make_example(pid=1, over="type 1"), + _make_example(pid=2, over="type 2"), + ] + mock_data = _make_information_json(examples) + dataset_loader = _MossBenchDataset() + + async def flaky_fetch(*, image_url: str, pid: str) -> str: + if pid == "1": + raise RuntimeError("network error") + return f"/fake/path/{pid}.png" + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object(dataset_loader, 
"_fetch_and_save_image_async", side_effect=flaky_fetch), + ): + dataset = await dataset_loader.fetch_dataset_async(cache=False) + + # pid=1 failed and was skipped; pid=2 succeeded with both seeds. + assert len(dataset.seeds) == 2 + pids = {s.metadata["pid"] for s in dataset.seeds} + assert pids == {"2"} + + async def test_fetch_dataset_rejects_non_dict_json(self): + dataset_loader = _MossBenchDataset() + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=[{"foo": "bar"}]), + pytest.raises(ValueError, match="dict keyed by pid"), + ): + await dataset_loader.fetch_dataset_async(cache=False)