diff --git a/doc/code/datasets/1_loading_datasets.ipynb b/doc/code/datasets/1_loading_datasets.ipynb index 4dc1c6ab9..09d10c884 100644 --- a/doc/code/datasets/1_loading_datasets.ipynb +++ b/doc/code/datasets/1_loading_datasets.ipynb @@ -73,10 +73,20 @@ " 'airt_scams',\n", " 'airt_sexual',\n", " 'airt_violence',\n", - " 'aya_redteaming',\n", - " 'babelscape_alert',\n", + " 'aya_redteaming_arabic',\n", + " 'aya_redteaming_english',\n", + " 'aya_redteaming_french',\n", + " 'aya_redteaming_hindi',\n", + " 'aya_redteaming_russian',\n", + " 'aya_redteaming_serbian',\n", + " 'aya_redteaming_spanish',\n", + " 'aya_redteaming_tagalog',\n", + " 'babelscape_alert_adversarial',\n", + " 'babelscape_alert_original',\n", " 'beaver_tails',\n", - " 'categorical_harmful_qa',\n", + " 'categorical_harmful_qa_chinese',\n", + " 'categorical_harmful_qa_english',\n", + " 'categorical_harmful_qa_vietnamese',\n", " 'cbt_bench',\n", " 'ccp_sensitive_prompts',\n", " 'comic_jailbreak',\n", @@ -90,7 +100,8 @@ " 'harmbench',\n", " 'harmbench_multimodal',\n", " 'harmful_qa',\n", - " 'hixstest',\n", + " 'hixstest_english',\n", + " 'hixstest_hindi',\n", " 'jbb_behaviors',\n", " 'librai_do_not_answer',\n", " 'llm_lat_harmful',\n", @@ -99,6 +110,16 @@ " 'ml_vlsu',\n", " 'mlcommons_ailuminate',\n", " 'msts',\n", + " 'msts_arabic',\n", + " 'msts_chinese',\n", + " 'msts_farsi',\n", + " 'msts_french',\n", + " 'msts_german',\n", + " 'msts_hindi',\n", + " 'msts_italian',\n", + " 'msts_korean',\n", + " 'msts_russian',\n", + " 'msts_spanish',\n", " 'multilingual_vulnerability',\n", " 'or_bench_80k',\n", " 'or_bench_hard',\n", @@ -108,7 +129,9 @@ " 'psfuzz_steal_system_prompt',\n", " 'pyrit_example_dataset',\n", " 'red_team_social_bias',\n", - " 'salad_bench',\n", + " 'salad_bench_attack_enhanced',\n", + " 'salad_bench_base',\n", + " 'salad_bench_defense_enhanced',\n", " 'sgxstest',\n", " 'simple_safety_tests',\n", " 'sorry_bench',\n", @@ -117,7 +140,9 @@ " 'toxic_chat',\n", " 'transphobia_awareness',\n", " 'visual_leak_bench',\n", - " 'vlguard',\n", + " 'vlguard_safe_safes',\n", + " 'vlguard_safe_unsafes',\n", + " 'vlguard_unsafes',\n", " 'xstest']" ] }, diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py index 284afac5a..d2dd331e8 100644 --- a/pyrit/datasets/seed_datasets/remote/__init__.py +++ b/pyrit/datasets/seed_datasets/remote/__init__.py @@ -11,16 +11,28 @@ _AegisContentSafetyDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import ( + AyaHarmCategory, + AyaLanguage, + _AyaRedteamingArabicDataset, _AyaRedteamingDataset, + _AyaRedteamingFrenchDataset, + _AyaRedteamingHindiDataset, + _AyaRedteamingRussianDataset, + _AyaRedteamingSerbianDataset, + _AyaRedteamingSpanishDataset, + _AyaRedteamingTagalogDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import ( _BabelscapeAlertDataset, + _BabelscapeAlertOriginalDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.beaver_tails_dataset import ( _BeaverTailsDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.categorical_harmful_qa_dataset import ( + _CategoricalHarmfulQAChineseDataset, _CategoricalHarmfulQADataset, + _CategoricalHarmfulQAVietnameseDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.cbt_bench_dataset import ( _CBTBenchDataset, @@ -57,6 +69,7 @@ from pyrit.datasets.seed_datasets.remote.hixstest_dataset import ( HiXSTestLanguage, _HiXSTestDataset, + _HiXSTestEnglishDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.jbb_behaviors_dataset import ( _JBBBehaviorsDataset, @@ -74,7 +87,17 @@ _MLCommonsAILuminateDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.msts_dataset import ( + _MSTSArabicDataset, + _MSTSChineseDataset, _MSTSDataset, + _MSTSFarsiDataset, + _MSTSFrenchDataset, + _MSTSGermanDataset, + _MSTSHindiDataset, + _MSTSItalianDataset, + _MSTSKoreanDataset, + _MSTSRussianDataset, + _MSTSSpanishDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.multilingual_vulnerability_dataset import ( # noqa: F401 _MultilingualVulnerabilityDataset, @@ -99,7 +122,9 @@ _RemoteDatasetLoader, ) from pyrit.datasets.seed_datasets.remote.salad_bench_dataset import ( + _SaladBenchAttackEnhancedDataset, _SaladBenchDataset, + _SaladBenchDefenseEnhancedDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.sgxstest_dataset import ( SGXSTestLabel, @@ -133,6 +158,8 @@ VLGuardSubcategory, VLGuardSubset, _VLGuardDataset, + _VLGuardSafeSafesDataset, + _VLGuardSafeUnsafesDataset, ) # noqa: F401 from pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset import ( _VLSUMultimodalDataset, @@ -142,6 +169,8 @@ ) # noqa: F401 __all__ = [ + "AyaHarmCategory", + "AyaLanguage", "HiXSTestLanguage", "PromptIntelCategory", "PromptIntelSeverity", @@ -150,12 +179,22 @@ "VLGuardSubcategory", "VLGuardSubset", "_AegisContentSafetyDataset", + "_AyaRedteamingArabicDataset", "_AyaRedteamingDataset", + "_AyaRedteamingFrenchDataset", + "_AyaRedteamingHindiDataset", + "_AyaRedteamingRussianDataset", + "_AyaRedteamingSerbianDataset", + "_AyaRedteamingSpanishDataset", + "_AyaRedteamingTagalogDataset", "_BabelscapeAlertDataset", + "_BabelscapeAlertOriginalDataset", "_BeaverTailsDataset", "_CBTBenchDataset", "_CCPSensitivePromptsDataset", + "_CategoricalHarmfulQAChineseDataset", "_CategoricalHarmfulQADataset", + "_CategoricalHarmfulQAVietnameseDataset", "_ComicJailbreakDataset", "COMIC_JAILBREAK_TEMPLATES", "ComicJailbreakTemplateConfig", @@ -167,12 +206,23 @@ "_HarmBenchMultimodalDataset", "_HarmfulQADataset", "_HiXSTestDataset", + "_HiXSTestEnglishDataset", "_JBBBehaviorsDataset", "_LibrAIDoNotAnswerDataset", "_LLMLatentAdversarialTrainingDataset", "_MedSafetyBenchDataset", "_MLCommonsAILuminateDataset", + "_MSTSArabicDataset", + "_MSTSChineseDataset", "_MSTSDataset", + "_MSTSFarsiDataset", + "_MSTSFrenchDataset", + "_MSTSGermanDataset", + "_MSTSHindiDataset", + "_MSTSItalianDataset", + "_MSTSKoreanDataset", + "_MSTSRussianDataset", + "_MSTSSpanishDataset", "_MultilingualVulnerabilityDataset", "_ORBench80KDataset", "_ORBenchHardDataset", @@ -182,7 +232,9 @@ "_RedTeamSocialBiasDataset", "_RemoteDatasetLoader", "_SGXSTestDataset", + "_SaladBenchAttackEnhancedDataset", "_SaladBenchDataset", + "_SaladBenchDefenseEnhancedDataset", "_SimpleSafetyTestsDataset", "_SOSBenchDataset", "_SorryBenchDataset", @@ -190,6 +242,8 @@ "_ToxicChatDataset", "_TransphobiaAwarenessDataset", "_VLGuardDataset", + "_VLGuardSafeSafesDataset", + "_VLGuardSafeUnsafesDataset", "_VLSUMultimodalDataset", "_VisualLeakBenchDataset", "VisualLeakBenchCategory", diff --git a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py index 3056c35e2..5e2bb859f 100644 --- a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py @@ -12,6 +12,20 @@ logger = logging.getLogger(__name__) +AyaLanguage = Literal["English", "Hindi", "French", "Spanish", "Arabic", "Russian", "Serbian", "Tagalog"] + +AyaHarmCategory = Literal[ + "Bullying & Harassment", + "Discrimination & Injustice", + "Graphic material", + "Harms of Representation Allocation and Quality of Service", + "Hate Speech", + "Non-consensual sexual content", + "Profanity", + "Self-Harm", + "Violence, Threats & Incitement", +] + class _AyaRedteamingDataset(_RemoteDatasetLoader): """ @@ -42,24 +56,8 @@ class _AyaRedteamingDataset(_RemoteDatasetLoader): def __init__( self, *, - language: Literal[ - "English", "Hindi", "French", "Spanish", "Arabic", "Russian", "Serbian", "Tagalog" - ] = "English", - harm_categories: Optional[ - list[ - Literal[ - "Bullying & Harassment", - "Discrimination & Injustice", - "Graphic material", - "Harms of Representation Allocation and Quality of Service", - "Hate Speech", - "Non-consensual sexual content", - "Profanity", - "Self-Harm", - "Violence, Threats & Incitement", - ] - ] - ] = None, + language: AyaLanguage = "English", + harm_categories: Optional[list[AyaHarmCategory]] = None, harm_scope: Optional[Literal["global", "local"]] = None, ) -> None: """ @@ -82,7 +80,7 @@ def __init__( @property def dataset_name(self) -> str: """Return the dataset name.""" - return "aya_redteaming" + return "aya_redteaming_english" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ @@ -130,3 +128,122 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: logger.info(f"Successfully loaded {len(seed_prompts)} prompts from Aya Red-teaming dataset") return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + +class _AyaRedteamingHindiDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the Hindi split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="Hindi", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_hindi" + + +class _AyaRedteamingFrenchDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the French split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="French", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_french" + + +class _AyaRedteamingSpanishDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the Spanish split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="Spanish", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_spanish" + + +class _AyaRedteamingArabicDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the Arabic split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="Arabic", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_arabic" + + +class _AyaRedteamingRussianDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the Russian split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="Russian", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_russian" + + +class _AyaRedteamingSerbianDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the Serbian split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="Serbian", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_serbian" + + +class _AyaRedteamingTagalogDataset(_AyaRedteamingDataset): + """Sibling loader pinned to the Tagalog split of the Aya Red-teaming dataset.""" + + def __init__( + self, + *, + harm_categories: Optional[list[AyaHarmCategory]] = None, + harm_scope: Optional[Literal["global", "local"]] = None, + ) -> None: + super().__init__(language="Tagalog", harm_categories=harm_categories, harm_scope=harm_scope) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "aya_redteaming_tagalog" diff --git a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py index 386d4190e..4e452ed6e 100644 --- a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py @@ -49,7 +49,7 @@ def __init__( @property def dataset_name(self) -> str: """Return the dataset name.""" - return "babelscape_alert" + return "babelscape_alert_adversarial" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ @@ -95,3 +95,31 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: logger.info(f"Successfully loaded {len(seed_prompts)} prompts from Babelscape Alert dataset") return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + +class _BabelscapeAlertOriginalDataset(_BabelscapeAlertDataset): + """ + Sibling loader pinned to the original (non-adversarial) ALERT split. + + Same as :class:`_BabelscapeAlertDataset` but defaults to ``category="alert"`` + so daily e2e tests exercise the 15k base ALERT config in addition to the + adversarial config covered by the parent. + """ + + def __init__( + self, + *, + source: str = "Babelscape/ALERT", + ) -> None: + """ + Initialize the Babelscape ALERT (original) dataset loader. + + Args: + source: HuggingFace dataset identifier. Defaults to "Babelscape/ALERT". + """ + super().__init__(source=source, category="alert") + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "babelscape_alert_original" diff --git a/pyrit/datasets/seed_datasets/remote/categorical_harmful_qa_dataset.py b/pyrit/datasets/seed_datasets/remote/categorical_harmful_qa_dataset.py index c2d1c10d5..ffbaf0339 100644 --- a/pyrit/datasets/seed_datasets/remote/categorical_harmful_qa_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/categorical_harmful_qa_dataset.py @@ -69,7 +69,7 @@ def __init__( @property def dataset_name(self) -> str: """Return the dataset name.""" - return "categorical_harmful_qa" + return "categorical_harmful_qa_english" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ @@ -129,3 +129,37 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: logger.info(f"Successfully loaded {len(seed_objectives)} objectives from CategoricalHarmfulQA dataset") return SeedDataset(seeds=seed_objectives, dataset_name=self.dataset_name) + + +class _CategoricalHarmfulQAChineseDataset(_CategoricalHarmfulQADataset): + """ + Sibling loader pinned to the Chinese split of CategoricalHarmfulQA. + + Same as :class:`_CategoricalHarmfulQADataset` but defaults to ``language="zh"`` + so daily e2e tests exercise the Chinese split as well as the English default. + """ + + def __init__(self) -> None: + super().__init__(language="zh") + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "categorical_harmful_qa_chinese" + + +class _CategoricalHarmfulQAVietnameseDataset(_CategoricalHarmfulQADataset): + """ + Sibling loader pinned to the Vietnamese split of CategoricalHarmfulQA. + + Same as :class:`_CategoricalHarmfulQADataset` but defaults to ``language="vi"`` + so daily e2e tests exercise the Vietnamese split as well as the English default. + """ + + def __init__(self) -> None: + super().__init__(language="vi") + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "categorical_harmful_qa_vietnamese" diff --git a/pyrit/datasets/seed_datasets/remote/hixstest_dataset.py b/pyrit/datasets/seed_datasets/remote/hixstest_dataset.py index e08cd5331..e6c73777b 100644 --- a/pyrit/datasets/seed_datasets/remote/hixstest_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/hixstest_dataset.py @@ -93,7 +93,7 @@ def __init__( @property def dataset_name(self) -> str: """Return the dataset name.""" - return "hixstest" + return "hixstest_hindi" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ @@ -187,3 +187,38 @@ def _select_value(self, item: dict) -> str: f"HiXSTest row is missing required field '{key}' for language={self.language.value}: {item!r}" ) return value + + +class _HiXSTestEnglishDataset(_HiXSTestDataset): + """ + Sibling loader pinned to the English-translation column of HiXSTest. + + Same as :class:`_HiXSTestDataset` but defaults to ``language=HiXSTestLanguage.ENGLISH`` + so daily e2e tests exercise the ``english_prompt`` column in addition to the + Hindi ``prompt`` column read by the parent. + + Note: This is the same gated HuggingFace dataset as :class:`_HiXSTestDataset` — + requires a token (``HUGGINGFACE_TOKEN`` env var or explicit kwarg) and accepting + the dataset terms on HuggingFace. + """ + + def __init__( + self, + *, + split: str = "train", + token: str | None = None, + ) -> None: + """ + Initialize the HiXSTest dataset loader pinned to the English column. + + Args: + split: Dataset split to load. Defaults to "train" (the only split). + token: Hugging Face authentication token. If not provided, reads from the + ``HUGGINGFACE_TOKEN`` environment variable. + """ + super().__init__(language=HiXSTestLanguage.ENGLISH, split=split, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "hixstest_english" diff --git a/pyrit/datasets/seed_datasets/remote/msts_dataset.py b/pyrit/datasets/seed_datasets/remote/msts_dataset.py index cc854aa66..934355ade 100644 --- a/pyrit/datasets/seed_datasets/remote/msts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/msts_dataset.py @@ -494,3 +494,203 @@ def _encode_pil_image(*, pil_image: "PILImage | None", extension: str) -> bytes except Exception as e: logger.warning(f"[MSTS] Failed to encode PIL image as {pil_format}: {e}. Will fall back to URL fetch.") return None + + +class _MSTSGermanDataset(_MSTSDataset): + """Sibling loader pinned to the German split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["de"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_german" + + +class _MSTSRussianDataset(_MSTSDataset): + """Sibling loader pinned to the Russian split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["ru"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_russian" + + +class _MSTSChineseDataset(_MSTSDataset): + """Sibling loader pinned to the Chinese split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["zh"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_chinese" + + +class _MSTSHindiDataset(_MSTSDataset): + """Sibling loader pinned to the Hindi split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["hi"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_hindi" + + +class _MSTSSpanishDataset(_MSTSDataset): + """Sibling loader pinned to the Spanish split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["es"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_spanish" + + +class _MSTSItalianDataset(_MSTSDataset): + """Sibling loader pinned to the Italian split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["it"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_italian" + + +class _MSTSFrenchDataset(_MSTSDataset): + """Sibling loader pinned to the French split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["fr"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_french" + + +class _MSTSKoreanDataset(_MSTSDataset): + """Sibling loader pinned to the Korean split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["ko"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_korean" + + +class _MSTSArabicDataset(_MSTSDataset): + """Sibling loader pinned to the Arabic split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["ar"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_arabic" + + +class _MSTSFarsiDataset(_MSTSDataset): + """Sibling loader pinned to the Farsi split of MSTS.""" + + size: str = "medium" + + def __init__( + self, + *, + text_modifiers: list[str] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__(languages=["fa"], text_modifiers=text_modifiers, max_examples=max_examples, token=token) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "msts_farsi" diff --git a/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py index 413dbd915..9421efd6f 100644 --- a/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py @@ -52,7 +52,7 @@ def __init__( @property def dataset_name(self) -> str: """Return the dataset name.""" - return "salad_bench" + return "salad_bench_base" @staticmethod def _parse_category(category: str) -> str: @@ -129,3 +129,47 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: logger.info(f"Successfully loaded {len(seed_prompts)} prompts from SALAD-Bench dataset") return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + +class _SaladBenchAttackEnhancedDataset(_SaladBenchDataset): + """ + Sibling loader pinned to the SALAD-Bench attack-enhanced split. + + Same as :class:`_SaladBenchDataset` but defaults to ``split="attackEnhanced"`` + so daily e2e tests exercise the attack-enhanced variant in addition to the + base split covered by the parent. + """ + + def __init__( + self, + *, + config: str = "prompts", + ) -> None: + super().__init__(config=config, split="attackEnhanced") + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "salad_bench_attack_enhanced" + + +class _SaladBenchDefenseEnhancedDataset(_SaladBenchDataset): + """ + Sibling loader pinned to the SALAD-Bench defense-enhanced split. + + Same as :class:`_SaladBenchDataset` but defaults to ``split="defenseEnhanced"`` + so daily e2e tests exercise the defense-enhanced variant in addition to the + base split covered by the parent. + """ + + def __init__( + self, + *, + config: str = "prompts", + ) -> None: + super().__init__(config=config, split="defenseEnhanced") + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "salad_bench_defense_enhanced" diff --git a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py index 1ab86cfe7..2893f481d 100644 --- a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py @@ -133,7 +133,7 @@ def __init__( @property def dataset_name(self) -> str: """Return the dataset name.""" - return "vlguard" + return "vlguard_unsafes" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ @@ -317,3 +317,73 @@ def _download_sync() -> tuple[str, str]: metadata = json.load(f) return metadata, image_dir + + +class _VLGuardSafeUnsafesDataset(_VLGuardDataset): + """ + Sibling loader pinned to the VLGuard ``safe_unsafes`` subset. + + Same as :class:`_VLGuardDataset` but defaults to + ``subset=VLGuardSubset.SAFE_UNSAFES`` so daily e2e tests exercise the + safe-images-with-unsafe-instructions evaluation regime (and the + ``unsafe_instruction`` field read) in addition to the ``UNSAFES`` default + covered by the parent. + + Note: This is the same gated HuggingFace dataset as :class:`_VLGuardDataset` — + requires a token (``HUGGINGFACE_TOKEN`` env var or explicit kwarg) and accepting + the dataset terms on HuggingFace. + """ + + def __init__( + self, + *, + categories: list[VLGuardCategory] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__( + subset=VLGuardSubset.SAFE_UNSAFES, + categories=categories, + max_examples=max_examples, + token=token, + ) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "vlguard_safe_unsafes" + + +class _VLGuardSafeSafesDataset(_VLGuardDataset): + """ + Sibling loader pinned to the VLGuard ``safe_safes`` subset. + + Same as :class:`_VLGuardDataset` but defaults to + ``subset=VLGuardSubset.SAFE_SAFES`` so daily e2e tests exercise the + safe-images-with-safe-instructions helpfulness regime (and the + ``safe_instruction`` field read) in addition to the ``UNSAFES`` default + covered by the parent. + + Note: This is the same gated HuggingFace dataset as :class:`_VLGuardDataset` — + requires a token (``HUGGINGFACE_TOKEN`` env var or explicit kwarg) and accepting + the dataset terms on HuggingFace. + """ + + def __init__( + self, + *, + categories: list[VLGuardCategory] | None = None, + max_examples: int | None = None, + token: str | None = None, + ) -> None: + super().__init__( + subset=VLGuardSubset.SAFE_SAFES, + categories=categories, + max_examples=max_examples, + token=token, + ) + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "vlguard_safe_safes" diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index bf19ecb5c..7960cda5b 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -16,6 +16,20 @@ logger = logging.getLogger(__name__) +# Maps legacy dataset_name strings (from before each loader's variant was made +# explicit in the name) to the new explicit dataset_name. Callers passing a +# legacy name to fetch_datasets_async() get the variant the loader was +# defaulting to plus a DeprecationWarning. Remove the corresponding entry in +# v0.16.0; the entire map can go once all entries have been removed. +_LEGACY_DATASET_NAME_ALIASES: dict[str, str] = { + "aya_redteaming": "aya_redteaming_english", + "babelscape_alert": "babelscape_alert_adversarial", + "categorical_harmful_qa": "categorical_harmful_qa_english", + "hixstest": "hixstest_hindi", + "salad_bench": "salad_bench_base", + "vlguard": "vlguard_unsafes", +} + class SeedDatasetProvider(ABC): """ @@ -273,6 +287,34 @@ def _match_single_criterion( return True + @classmethod + def _resolve_legacy_dataset_names(cls, dataset_names: list[str]) -> list[str]: + """ + Resolve any legacy dataset names to their current canonical names. + + Emits a ``DeprecationWarning`` for each legacy name encountered. + + Args: + dataset_names: Requested dataset names (may include legacy names). + + Returns: + list[str]: The same list with any legacy names replaced by their + current canonical equivalents. + """ + resolved: list[str] = [] + for name in dataset_names: + new_name = _LEGACY_DATASET_NAME_ALIASES.get(name) + if new_name is not None: + warnings.warn( + f"Dataset name '{name}' is deprecated; use '{new_name}' instead. Will be removed in v0.16.0.", + DeprecationWarning, + stacklevel=3, + ) + resolved.append(new_name) + else: + resolved.append(name) + return resolved + @classmethod async def fetch_datasets_async( cls, @@ -312,6 +354,7 @@ async def fetch_datasets_async( """ # Validate dataset names if specified if dataset_names is not None: + dataset_names = cls._resolve_legacy_dataset_names(dataset_names) available_names = await cls.get_all_dataset_names_async() invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: diff --git a/tests/end_to_end/test_all_datasets.py b/tests/end_to_end/test_all_datasets.py index fbc609085..63c38a74c 100644 --- a/tests/end_to_end/test_all_datasets.py +++ b/tests/end_to_end/test_all_datasets.py @@ -22,7 +22,21 @@ from pyrit.datasets import SeedDatasetProvider from pyrit.datasets.seed_datasets.remote import ( _HarmBenchMultimodalDataset, + _MSTSArabicDataset, + _MSTSChineseDataset, + _MSTSDataset, + _MSTSFarsiDataset, + _MSTSFrenchDataset, + _MSTSGermanDataset, + _MSTSHindiDataset, + _MSTSItalianDataset, + _MSTSKoreanDataset, + _MSTSRussianDataset, + _MSTSSpanishDataset, _PromptIntelDataset, + _VLGuardDataset, + _VLGuardSafeSafesDataset, + _VLGuardSafeUnsafesDataset, _VLSUMultimodalDataset, ) from pyrit.models import SeedDataset @@ -40,6 +54,28 @@ # due to rate-limiting, so an empty result is expected in some environments. _IMAGE_FETCHING_PROVIDERS: set[type] = {_HarmBenchMultimodalDataset, _VLSUMultimodalDataset} +# Image-heavy providers where we cap rows in e2e to keep CI bounded. Production +# defaults are unchanged; this is a test-only override applied only when +# instantiating the provider for the daily e2e sweep. +_CAPPED_PROVIDERS: set[type] = { + _VLSUMultimodalDataset, + _MSTSDataset, + _MSTSGermanDataset, + _MSTSRussianDataset, + _MSTSChineseDataset, + _MSTSHindiDataset, + _MSTSSpanishDataset, + _MSTSItalianDataset, + _MSTSFrenchDataset, + _MSTSKoreanDataset, + _MSTSArabicDataset, + _MSTSFarsiDataset, + _VLGuardDataset, + _VLGuardSafeUnsafesDataset, + _VLGuardSafeSafesDataset, +} +_CAPPED_MAX_EXAMPLES = 6 + def get_dataset_providers(): """Helper to get all registered providers for parameterization.""" @@ -47,6 +83,13 @@ def get_dataset_providers(): return [(name, cls) for name, cls in providers.items()] +def _instantiate_provider(provider_cls): + """Instantiate a provider for e2e, applying CI-only row caps to image-heavy variants.""" + if provider_cls in _CAPPED_PROVIDERS: + return provider_cls(max_examples=_CAPPED_MAX_EXAMPLES) + return provider_cls() + + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=5, min=5, max=60), @@ -88,8 +131,7 @@ async def test_fetch_dataset(self, name, provider_cls): logger.info(f"Testing provider: {name}") try: - # Limit examples for slow multimodal providers that fetch many remote images - provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + provider = _instantiate_provider(provider_cls) dataset = await _fetch_with_retry(provider) except Exception as e: diff --git a/tests/unit/datasets/test_aya_redteaming_dataset.py b/tests/unit/datasets/test_aya_redteaming_dataset.py index 6cc4fa19d..8d4d214be 100644 --- a/tests/unit/datasets/test_aya_redteaming_dataset.py +++ b/tests/unit/datasets/test_aya_redteaming_dataset.py @@ -5,7 +5,16 @@ import pytest -from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import _AyaRedteamingDataset +from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import ( + _AyaRedteamingArabicDataset, + _AyaRedteamingDataset, + _AyaRedteamingFrenchDataset, + _AyaRedteamingHindiDataset, + _AyaRedteamingRussianDataset, + _AyaRedteamingSerbianDataset, + _AyaRedteamingSpanishDataset, + _AyaRedteamingTagalogDataset, +) from pyrit.models import SeedDataset, SeedPrompt @@ -60,9 +69,47 @@ async def test_fetch_dataset_filters_by_harm_scope(mock_aya_data): def test_dataset_name(): loader = _AyaRedteamingDataset() - assert loader.dataset_name == "aya_redteaming" + assert loader.dataset_name == "aya_redteaming_english" def test_language_code_mapping(): loader = _AyaRedteamingDataset(language="French") assert "fra" in loader.source + + +# Sibling-subclass variants (one registered provider per non-English language). +_AYA_LANGUAGE_VARIANTS = [ + (_AyaRedteamingHindiDataset, "aya_redteaming_hindi", "hin"), + (_AyaRedteamingFrenchDataset, "aya_redteaming_french", "fra"), + (_AyaRedteamingSpanishDataset, "aya_redteaming_spanish", "spa"), + (_AyaRedteamingArabicDataset, "aya_redteaming_arabic", "arb"), + (_AyaRedteamingRussianDataset, "aya_redteaming_russian", "rus"), + (_AyaRedteamingSerbianDataset, "aya_redteaming_serbian", "srp"), + (_AyaRedteamingTagalogDataset, "aya_redteaming_tagalog", "tgl"), +] + + +@pytest.mark.parametrize("cls,expected_name,expected_lang_code", _AYA_LANGUAGE_VARIANTS) +def test_language_variant_dataset_name_and_source(cls, expected_name, expected_lang_code): + loader = cls() + assert loader.dataset_name == expected_name + assert f"aya_{expected_lang_code}.jsonl" in loader.source + + +@pytest.mark.parametrize("cls,expected_name,expected_lang_code", _AYA_LANGUAGE_VARIANTS) +async def test_language_variant_fetch_dataset(cls, expected_name, expected_lang_code, mock_aya_data): + loader = cls() + with patch.object(loader, "_fetch_from_url", return_value=mock_aya_data): + dataset = await loader.fetch_dataset_async() + assert dataset.dataset_name == expected_name + assert len(dataset.seeds) == 2 + for seed in dataset.seeds: + assert isinstance(seed, SeedPrompt) + assert seed.dataset_name == expected_name + + +@pytest.mark.parametrize("cls,expected_name,_lang_code", _AYA_LANGUAGE_VARIANTS) +def test_language_variant_filters_forward(cls, expected_name, _lang_code, mock_aya_data): + loader = cls(harm_categories=["Hate Speech"], harm_scope="global") + assert loader.harm_categories_filter == ["Hate Speech"] + assert loader.harm_scope == "global" diff --git a/tests/unit/datasets/test_babelscape_alert_dataset.py b/tests/unit/datasets/test_babelscape_alert_dataset.py index 476867762..b9647e17f 100644 --- a/tests/unit/datasets/test_babelscape_alert_dataset.py +++ b/tests/unit/datasets/test_babelscape_alert_dataset.py @@ -5,7 +5,10 @@ import pytest -from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import _BabelscapeAlertDataset +from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import ( + _BabelscapeAlertDataset, + _BabelscapeAlertOriginalDataset, +) from pyrit.models import SeedDataset, SeedPrompt @@ -65,9 +68,31 @@ async def test_fetch_dataset_includes_harm_categories(self, mock_alert_data): def test_dataset_name(self): """Test dataset_name property.""" loader = _BabelscapeAlertDataset() - assert loader.dataset_name == "babelscape_alert" + assert loader.dataset_name == "babelscape_alert_adversarial" def test_invalid_category_raises_error(self): """Test that invalid category raises ValueError.""" with pytest.raises(ValueError): _BabelscapeAlertDataset(category="invalid_category") + + +class TestBabelscapeAlertOriginalDataset: + """Sibling loader pinned to the original (smaller) 'alert' config.""" + + def test_dataset_name(self): + loader = _BabelscapeAlertOriginalDataset() + assert loader.dataset_name == "babelscape_alert_original" + + def test_pins_alert_category(self): + loader = _BabelscapeAlertOriginalDataset() + assert loader.category == "alert" + + async def test_fetch_dataset_passes_alert_config(self, mock_alert_data): + loader = _BabelscapeAlertOriginalDataset() + with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_alert_data)) as mock_fetch: + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == "babelscape_alert_original" + assert mock_fetch.await_count == 1 + assert mock_fetch.await_args.kwargs["config"] == "alert" diff --git a/tests/unit/datasets/test_categorical_harmful_qa_dataset.py b/tests/unit/datasets/test_categorical_harmful_qa_dataset.py index c46205fb4..c322ff1da 100644 --- a/tests/unit/datasets/test_categorical_harmful_qa_dataset.py +++ b/tests/unit/datasets/test_categorical_harmful_qa_dataset.py @@ -6,7 +6,9 @@ import pytest from pyrit.datasets.seed_datasets.remote.categorical_harmful_qa_dataset import ( + _CategoricalHarmfulQAChineseDataset, _CategoricalHarmfulQADataset, + _CategoricalHarmfulQAVietnameseDataset, ) from pyrit.models import SeedDataset, SeedObjective @@ -51,7 +53,7 @@ async def test_fetch_dataset_default_english(self, mock_catqa_data): assert first.harm_categories == ["Illegal Activity"] assert first.metadata["subcategory"] == "Drug" assert first.metadata["language"] == "en" - assert first.dataset_name == "categorical_harmful_qa" + assert first.dataset_name == "categorical_harmful_qa_english" third = dataset.seeds[2] assert third.harm_categories == ["Fraud/Deception"] @@ -90,4 +92,31 @@ async def test_fetch_dataset_with_empty_category(self): def test_dataset_name(self): loader = _CategoricalHarmfulQADataset() - assert loader.dataset_name == "categorical_harmful_qa" + assert loader.dataset_name == "categorical_harmful_qa_english" + + +_CATQA_LANGUAGE_VARIANTS = [ + (_CategoricalHarmfulQAChineseDataset, "categorical_harmful_qa_chinese", "zh"), + (_CategoricalHarmfulQAVietnameseDataset, "categorical_harmful_qa_vietnamese", "vi"), +] + + +@pytest.mark.parametrize("cls,expected_name,expected_lang", _CATQA_LANGUAGE_VARIANTS) +def test_language_variant_dataset_name_and_pinned_split(cls, expected_name, expected_lang): + loader = cls() + assert loader.dataset_name == expected_name + assert loader.language == expected_lang + + +@pytest.mark.parametrize("cls,expected_name,expected_lang", _CATQA_LANGUAGE_VARIANTS) +async def test_language_variant_fetch_passes_pinned_split(cls, expected_name, expected_lang, mock_catqa_data): + loader = cls() + with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_catqa_data)) as mock_fetch: + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == expected_name + assert mock_fetch.await_args.kwargs["split"] == expected_lang + for seed in dataset.seeds: + assert seed.metadata["language"] == expected_lang + assert seed.dataset_name == expected_name diff --git a/tests/unit/datasets/test_hixstest_dataset.py b/tests/unit/datasets/test_hixstest_dataset.py index d9b5e982a..e3d52d48f 100644 --- a/tests/unit/datasets/test_hixstest_dataset.py +++ b/tests/unit/datasets/test_hixstest_dataset.py @@ -8,6 +8,7 @@ from pyrit.datasets.seed_datasets.remote.hixstest_dataset import ( HiXSTestLanguage, _HiXSTestDataset, + _HiXSTestEnglishDataset, ) from pyrit.models import SeedDataset, SeedPrompt @@ -37,7 +38,7 @@ class TestHiXSTestDataset: def test_dataset_name(self): """Test dataset_name property.""" loader = _HiXSTestDataset() - assert loader.dataset_name == "hixstest" + assert loader.dataset_name == "hixstest_hindi" def test_init_defaults_to_hindi(self): """Default language is Hindi.""" @@ -165,3 +166,29 @@ async def test_fetch_dataset_missing_english_field_raises(self): with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=bad_data)): with pytest.raises(ValueError, match="missing required field 'english_prompt'"): await loader.fetch_dataset_async() + + +class TestHiXSTestEnglishDataset: + """Sibling loader pinned to the English column of HiXSTest.""" + + def test_dataset_name(self): + loader = _HiXSTestEnglishDataset() + assert loader.dataset_name == "hixstest_english" + + def test_pins_english_language(self): + loader = _HiXSTestEnglishDataset() + assert loader.language is HiXSTestLanguage.ENGLISH + + def test_token_forwarded(self): + loader = _HiXSTestEnglishDataset(token="explicit-token") + assert loader.token == "explicit-token" + + async def test_fetch_dataset_uses_english_column(self, mock_hixstest_data): + loader = _HiXSTestEnglishDataset() + with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_hixstest_data)): + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == "hixstest_english" + assert dataset.seeds[0].value == "How can I kill my neighbour?" + assert dataset.seeds[0].metadata["language"] == "en" diff --git a/tests/unit/datasets/test_msts_dataset.py b/tests/unit/datasets/test_msts_dataset.py index 894b1796e..209a156ab 100644 --- a/tests/unit/datasets/test_msts_dataset.py +++ b/tests/unit/datasets/test_msts_dataset.py @@ -7,7 +7,20 @@ import pytest -from pyrit.datasets.seed_datasets.remote.msts_dataset import _MSTSDataset +from pyrit.datasets.seed_datasets.remote.msts_dataset import ( + _LANGUAGE_TO_SPLIT, + _MSTSArabicDataset, + _MSTSChineseDataset, + _MSTSDataset, + _MSTSFarsiDataset, + _MSTSFrenchDataset, + _MSTSGermanDataset, + _MSTSHindiDataset, + _MSTSItalianDataset, + _MSTSKoreanDataset, + _MSTSRussianDataset, + _MSTSSpanishDataset, +) from pyrit.models import SeedDataset @@ -541,3 +554,59 @@ async def test_fetch_and_save_image_continues_when_path_exists_raises(tmp_path): mock_serializer.save_data.assert_awaited_once_with(data=b"recovered-bytes", output_filename="msts_img_0003") assert result == str(Path(str(tmp_path) + "/images", "msts_img_0003.jpg")) + + +# Sibling-subclass variants (one registered provider per non-English language). +_MSTS_LANGUAGE_VARIANTS = [ + (_MSTSGermanDataset, "msts_german", "de"), + (_MSTSRussianDataset, "msts_russian", "ru"), + (_MSTSChineseDataset, "msts_chinese", "zh"), + (_MSTSHindiDataset, "msts_hindi", "hi"), + (_MSTSSpanishDataset, "msts_spanish", "es"), + (_MSTSItalianDataset, "msts_italian", "it"), + (_MSTSFrenchDataset, "msts_french", "fr"), + (_MSTSKoreanDataset, "msts_korean", "ko"), + (_MSTSArabicDataset, "msts_arabic", "ar"), + (_MSTSFarsiDataset, "msts_farsi", "fa"), +] + + +@pytest.mark.parametrize("cls,expected_name,expected_lang", _MSTS_LANGUAGE_VARIANTS) +def test_language_variant_dataset_name_and_pinned_language(cls, expected_name, expected_lang): + loader = cls() + assert loader.dataset_name == expected_name + assert loader.languages == [expected_lang] + + +@pytest.mark.parametrize("cls,expected_name,expected_lang", _MSTS_LANGUAGE_VARIANTS) +def test_language_variant_max_examples_forwarded(cls, expected_name, expected_lang): + loader = cls(max_examples=4) + assert loader.max_examples == 4 + assert loader.dataset_name == expected_name + + +@pytest.mark.parametrize("cls,expected_name,expected_lang", _MSTS_LANGUAGE_VARIANTS) +def test_language_variant_text_modifiers_forwarded(cls, expected_name, expected_lang): + loader = cls(text_modifiers=["assistance"]) + assert loader.text_modifiers == ["assistance"] + + +@pytest.mark.parametrize("cls,expected_name,expected_lang", _MSTS_LANGUAGE_VARIANTS) +async def test_language_variant_fetch_passes_pinned_language(cls, expected_name, expected_lang, english_rows): + loader = cls() + with ( + patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=english_rows)) as mock_fetch, + patch.object( + loader, + "_fetch_and_save_image_async", + new=AsyncMock(return_value="/tmp/msts.jpg"), + ), + ): + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == expected_name + # _fetch_from_huggingface is called once per language; for the sibling subclasses + # that means exactly one call with the pinned split. + assert mock_fetch.await_count == 1 + assert mock_fetch.await_args.kwargs["split"] == _LANGUAGE_TO_SPLIT[expected_lang] diff --git a/tests/unit/datasets/test_salad_bench_dataset.py b/tests/unit/datasets/test_salad_bench_dataset.py index 62924dbe6..792546e7e 100644 --- a/tests/unit/datasets/test_salad_bench_dataset.py +++ b/tests/unit/datasets/test_salad_bench_dataset.py @@ -5,7 +5,11 @@ import pytest -from pyrit.datasets.seed_datasets.remote.salad_bench_dataset import _SaladBenchDataset +from pyrit.datasets.seed_datasets.remote.salad_bench_dataset import ( + _SaladBenchAttackEnhancedDataset, + _SaladBenchDataset, + _SaladBenchDefenseEnhancedDataset, +) from pyrit.models import SeedDataset, SeedPrompt @@ -43,7 +47,7 @@ async def test_fetch_dataset(self, mock_salad_bench_data): first_prompt = dataset.seeds[0] assert first_prompt.value == "How to create a harmful website?" assert first_prompt.data_type == "text" - assert first_prompt.dataset_name == "salad_bench" + assert first_prompt.dataset_name == "salad_bench_base" assert first_prompt.harm_categories == ["Representation & Toxicity", "Hate Speech"] def test_parse_category(self): @@ -56,7 +60,7 @@ def test_parse_category(self): def test_dataset_name(self): """Test dataset_name property.""" loader = _SaladBenchDataset() - assert loader.dataset_name == "salad_bench" + assert loader.dataset_name == "salad_bench_base" async def test_fetch_dataset_with_custom_config(self, mock_salad_bench_data): """Test fetching with custom config.""" @@ -76,3 +80,31 @@ async def test_fetch_dataset_with_custom_config(self, mock_salad_bench_data): assert call_kwargs["dataset_name"] == "walledai/SaladBench" assert call_kwargs["config"] == "prompts" assert call_kwargs["split"] == "attackEnhanced" + + +_SALAD_SPLIT_VARIANTS = [ + (_SaladBenchAttackEnhancedDataset, "salad_bench_attack_enhanced", "attackEnhanced"), + (_SaladBenchDefenseEnhancedDataset, "salad_bench_defense_enhanced", "defenseEnhanced"), +] + + +@pytest.mark.parametrize("cls,expected_name,expected_split", _SALAD_SPLIT_VARIANTS) +def test_split_variant_dataset_name_and_pinned_split(cls, expected_name, expected_split): + loader = cls() + assert loader.dataset_name == expected_name + assert loader.split == expected_split + + +@pytest.mark.parametrize("cls,expected_name,expected_split", _SALAD_SPLIT_VARIANTS) +async def test_split_variant_fetch_passes_pinned_split(cls, expected_name, expected_split, mock_salad_bench_data): + loader = cls() + with patch.object( + loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_salad_bench_data) + ) as mock_fetch: + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == expected_name + call_kwargs = mock_fetch.call_args.kwargs + assert call_kwargs["split"] == expected_split + assert call_kwargs["dataset_name"] == "walledai/SaladBench" diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 4eb058f6a..accd22358 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -242,6 +242,82 @@ def dataset_name(self) -> str: await provider.fetch_dataset_async() +class TestLegacyDatasetNameAliases: + """Tests for the legacy dataset_name -> canonical name deprecation aliasing.""" + + @pytest.fixture + def mocked_registry(self): + mock_english = MagicMock(__name__="AyaEnglish") + mock_english.return_value.dataset_name = "aya_redteaming_english" + mock_english.return_value._parse_metadata = AsyncMock(return_value=None) + mock_english.return_value.fetch_dataset_async = AsyncMock( + return_value=SeedDataset( + seeds=[SeedPrompt(value="x", data_type="text")], + dataset_name="aya_redteaming_english", + ) + ) + + mock_other = MagicMock(__name__="OtherDataset") + mock_other.return_value.dataset_name = "other_dataset" + mock_other.return_value._parse_metadata = AsyncMock(return_value=None) + mock_other.return_value.fetch_dataset_async = AsyncMock( + return_value=SeedDataset( + seeds=[SeedPrompt(value="y", data_type="text")], + dataset_name="other_dataset", + ) + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"AyaEnglish": mock_english, "OtherDataset": mock_other}, + clear=True, + ): + yield + + async def test_legacy_name_resolves_with_warning(self, mocked_registry): + """Legacy 'aya_redteaming' resolves to 'aya_redteaming_english' with a DeprecationWarning.""" + with pytest.warns(DeprecationWarning, match="'aya_redteaming' is deprecated"): + datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["aya_redteaming"]) + assert len(datasets) == 1 + assert datasets[0].dataset_name == "aya_redteaming_english" + + async def test_canonical_name_does_not_warn(self, mocked_registry): + """Canonical 'aya_redteaming_english' fetches without any DeprecationWarning.""" + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["aya_redteaming_english"]) + assert len(datasets) == 1 + + async def test_mixed_legacy_and_canonical_resolves(self, mocked_registry): + """Mixing legacy and current names resolves only the legacy one.""" + with pytest.warns(DeprecationWarning, match="'aya_redteaming' is deprecated"): + datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["aya_redteaming", "other_dataset"]) + names = {d.dataset_name for d in datasets} + assert names == {"aya_redteaming_english", "other_dataset"} + + async def test_unknown_name_still_raises_value_error(self, mocked_registry): + """Names that are neither legacy nor current still raise ValueError.""" + with pytest.raises(ValueError, match=r"Dataset\(s\) not found: \['nonexistent'\]"): + await SeedDatasetProvider.fetch_datasets_async(dataset_names=["nonexistent"]) + + @pytest.mark.parametrize( + "legacy,canonical", + [ + ("aya_redteaming", "aya_redteaming_english"), + ("babelscape_alert", "babelscape_alert_adversarial"), + ("categorical_harmful_qa", "categorical_harmful_qa_english"), + ("hixstest", "hixstest_hindi"), + ("salad_bench", "salad_bench_base"), + ("vlguard", "vlguard_unsafes"), + ], + ) + def test_alias_map_entries(self, legacy, canonical): + """The alias map covers every renamed dataset and resolves to the canonical name.""" + from pyrit.datasets.seed_datasets.seed_dataset_provider import _LEGACY_DATASET_NAME_ALIASES + + assert _LEGACY_DATASET_NAME_ALIASES[legacy] == canonical + + class TestHarmBenchDataset: """Test the HarmBench dataset loader.""" diff --git a/tests/unit/datasets/test_vlguard_dataset.py b/tests/unit/datasets/test_vlguard_dataset.py index 7cc8c3506..3610bdc6f 100644 --- a/tests/unit/datasets/test_vlguard_dataset.py +++ b/tests/unit/datasets/test_vlguard_dataset.py @@ -12,6 +12,8 @@ VLGuardCategory, VLGuardSubset, _VLGuardDataset, + _VLGuardSafeSafesDataset, + _VLGuardSafeUnsafesDataset, ) from pyrit.models import SeedDataset, SeedPrompt @@ -72,7 +74,7 @@ class TestVLGuardDataset: def test_dataset_name(self): """Test dataset_name property.""" loader = _VLGuardDataset() - assert loader.dataset_name == "vlguard" + assert loader.dataset_name == "vlguard_unsafes" def test_default_subset_is_unsafes(self): """Test default subset is UNSAFES.""" @@ -399,3 +401,45 @@ def mock_hf_download(*, repo_id, filename, repo_type, local_dir, token): assert metadata == test_metadata assert result_dir == cache_dir / "test" + + +_VLGUARD_SUBSET_VARIANTS = [ + (_VLGuardSafeUnsafesDataset, "vlguard_safe_unsafes", VLGuardSubset.SAFE_UNSAFES), + (_VLGuardSafeSafesDataset, "vlguard_safe_safes", VLGuardSubset.SAFE_SAFES), +] + + +@pytest.mark.parametrize("cls,expected_name,expected_subset", _VLGUARD_SUBSET_VARIANTS) +def test_subset_variant_dataset_name_and_pinned_subset(cls, expected_name, expected_subset): + loader = cls() + assert loader.dataset_name == expected_name + assert loader.subset == expected_subset + + +@pytest.mark.parametrize("cls,expected_name,expected_subset", _VLGUARD_SUBSET_VARIANTS) +def test_subset_variant_max_examples_and_token_forwarded(cls, expected_name, expected_subset): + loader = cls(max_examples=4, token="explicit-token") + assert loader.max_examples == 4 + assert loader.token == "explicit-token" + assert loader.dataset_name == expected_name + + +@pytest.mark.parametrize("cls,expected_name,expected_subset", _VLGUARD_SUBSET_VARIANTS) +async def test_subset_variant_fetch(cls, expected_name, expected_subset, mock_vlguard_metadata, tmp_path): + image_dir = tmp_path / "test" + image_dir.mkdir() + (image_dir / "safe_001.jpg").write_bytes(b"fake image") + + loader = cls() + with patch.object( + loader, + "_download_dataset_files_async", + new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), + ): + dataset = await loader.fetch_dataset_async() + + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == expected_name + text_prompts = [p for p in dataset.seeds if p.data_type == "text"] + assert len(text_prompts) >= 1 + assert text_prompts[0].metadata["subset"] == expected_subset.value