From 83b40f405c6e6351d7cef88f8eca2f8428149fda Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 22 May 2026 19:20:29 -0700 Subject: [PATCH] MAINT: remove unused per-loader max_examples knobs Drops the per-loader max_examples/max_prompts row-limit parameters from five remote seed dataset loaders that have no callers anywhere in pyrit/, tests/, or doc/ other than the loaders' own unit tests (removed here): comic_jailbreak, msts, promptintel, visual_leak_bench, vlguard. DatasetConfiguration.max_dataset_size already provides group-aware random sampling at the scenario layer, so the per-loader knobs were a redundant second sampling primitive (and a footgun: they keep a deterministic prefix of the rows rather than a random sample). They were also unreachable from SeedDatasetProvider.fetch_datasets_async, which instantiates each loader with no args. vlsu_multimodal is intentionally left alone because its max_examples parameter is the only one with a non-self caller anywhere (tests/end_to_end/test_all_datasets.py). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../remote/comic_jailbreak_dataset.py | 14 +------- .../seed_datasets/remote/msts_dataset.py | 19 ---------- .../remote/promptintel_dataset.py | 12 ------- .../remote/visual_leak_bench_dataset.py | 8 ----- .../seed_datasets/remote/vlguard_dataset.py | 8 ----- .../datasets/test_comic_jailbreak_dataset.py | 15 -------- tests/unit/datasets/test_msts_dataset.py | 36 ------------------- .../unit/datasets/test_promptintel_dataset.py | 10 ------ .../test_visual_leak_bench_dataset.py | 24 ------------- tests/unit/datasets/test_vlguard_dataset.py | 19 ---------- 10 files changed, 1 insertion(+), 164 deletions(-) diff --git a/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py b/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py index b4d43bca0..d4f748f1f 100644 --- a/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py @@ -105,7 +105,6 @@ def __init__( ), source_type: Literal["public_url", "file"] = "public_url", 
templates: list[str] | None = None, - max_examples: int | None = None, ) -> None: """ Initialize the ComicJailbreak dataset loader. @@ -115,8 +114,6 @@ def __init__( at a pinned commit. source_type: The type of source ('public_url' or 'file'). templates: List of template names to include. If None, all 5 templates are used. - max_examples: Maximum number of goal×template pairs to produce. If None, all - combinations are returned. Raises: ValueError: If any template name is invalid. @@ -124,7 +121,6 @@ def __init__( self.source = source self.source_type: Literal["public_url", "file"] = source_type self.templates = templates or list(self.TEMPLATE_NAMES) - self.max_examples = max_examples invalid = set(self.templates) - set(self.TEMPLATE_NAMES) if invalid: @@ -170,7 +166,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: template_paths[template_name] = await self._fetch_template_async(template_name) seeds: list[Seed] = [] - pair_count = 0 for row_idx, example in enumerate(examples): missing_keys = required_keys - example.keys() @@ -208,15 +203,8 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: behavior=example.get("Behavior", ""), ) seeds.extend(pair) - pair_count += 1 - if self.max_examples is not None and pair_count >= self.max_examples: - break - - if self.max_examples is not None and pair_count >= self.max_examples: - break - - logger.info(f"Successfully loaded {len(seeds)} seeds ({pair_count} groups) from ComicJailbreak dataset") + logger.info(f"Successfully loaded {len(seeds)} seeds from ComicJailbreak dataset") return SeedDataset(seeds=seeds, dataset_name=self.dataset_name) def _build_seed_group( diff --git a/pyrit/datasets/seed_datasets/remote/msts_dataset.py b/pyrit/datasets/seed_datasets/remote/msts_dataset.py index cc854aa66..bb9038f17 100644 --- a/pyrit/datasets/seed_datasets/remote/msts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/msts_dataset.py @@ -134,7 +134,6 @@ def __init__( *, 
languages: list[str] | None = None, text_modifiers: list[str] | None = None, - max_examples: int | None = None, token: str | None = None, ) -> None: """ @@ -146,11 +145,6 @@ def __init__( language. Defaults to ``["en"]``. text_modifiers (list[str] | None): Subset of {"assistance", "intention"} to include. Defaults to both. - max_examples (int | None): Maximum number of test pairs to successfully load - across all languages and modifiers combined. Rows whose image fetch fails - are skipped and do NOT count toward this limit, so a request for N examples - returns up to N pairs as long as enough source rows succeed. Each loaded - pair produces 2 SeedPrompts (image + text). Defaults to None (no limit). token (str | None): Optional HuggingFace authentication token. Raises: @@ -158,7 +152,6 @@ def __init__( """ self.languages = self._resolve_languages(languages) self.text_modifiers = self._resolve_text_modifiers(text_modifiers) - self.max_examples = max_examples self.token = token self.source = f"https://huggingface.co/datasets/{_HF_REPO_ID}" @@ -175,10 +168,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: and both have ``sequence=0``, so they are delivered together as a single multimodal user message (image + text) rather than as two separate turns. - When ``max_examples`` is set, only successfully loaded pairs count toward the - limit; rows whose image fetch fails are logged and skipped without consuming - the quota. - Args: cache (bool): Whether to cache the fetched dataset and images. Defaults to True. 
@@ -189,7 +178,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: prompts: list[SeedPrompt] = [] failed_image_count = 0 - successful_pairs = 0 for language in self.languages: split_name = _LANGUAGE_TO_SPLIT[language] @@ -201,9 +189,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: ) for row in split_data: - if self.max_examples is not None and successful_pairs >= self.max_examples: - break - if row.get("prompt_type") not in self.text_modifiers: continue @@ -218,10 +203,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: continue prompts.extend(pair) - successful_pairs += 1 - - if self.max_examples is not None and successful_pairs >= self.max_examples: - break if failed_image_count > 0: logger.warning(f"[MSTS] Skipped {failed_image_count} example(s) due to image fetch failures") diff --git a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py index 7af64c4fe..51e99d35a 100644 --- a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py @@ -75,7 +75,6 @@ def __init__( severity: Optional[PromptIntelSeverity] = None, categories: Optional[list[PromptIntelCategory]] = None, search: Optional[str] = None, - max_prompts: Optional[int] = None, ) -> None: """ Initialize the PromptIntel dataset loader. @@ -87,7 +86,6 @@ def __init__( When multiple categories are specified, separate API requests are made for each category and results are merged with deduplication. search: Search term to filter prompts by title and content. Defaults to None. - max_prompts: Maximum number of prompts to fetch. Defaults to None (all available). Raises: ValueError: If an invalid severity or category is provided. 
@@ -103,7 +101,6 @@ def __init__( self._severity = severity self._categories = categories self._search = search - self._max_prompts = max_prompts self.source = "https://promptintel.novahunting.ai" @property @@ -177,21 +174,12 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]: seen_ids.add(record_id) all_prompts.append(record) - # Check if we've reached the max_prompts limit - if self._max_prompts and len(all_prompts) >= self._max_prompts: - all_prompts = all_prompts[: self._max_prompts] - break - # Check if there are more pages total_pages = pagination.get("pages", 1) if page >= total_pages: break page += 1 - # Also break the outer loop if max_prompts reached - if self._max_prompts and len(all_prompts) >= self._max_prompts: - break - return all_prompts def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]: diff --git a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py index 2edd921eb..2a74a2d2d 100644 --- a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py @@ -84,7 +84,6 @@ def __init__( source_type: Literal["public_url", "file"] = "public_url", categories: Optional[list[VisualLeakBenchCategory]] = None, pii_types: Optional[list[VisualLeakBenchPIIType]] = None, - max_examples: Optional[int] = None, ) -> None: """ Initialize the VisualLeakBench dataset loader. @@ -98,9 +97,6 @@ def __init__( VisualLeakBenchCategory.PII_LEAKAGE. pii_types: List of PII types to include (only relevant for PII_LEAKAGE category). If None, all PII types are included. - max_examples: Maximum number of examples to fetch. Each example produces 2 prompts - (image + text). If None, fetches all examples. Useful for testing or quick - validations. Raises: ValueError: If any of the specified categories or pii_types are invalid. 
@@ -109,7 +105,6 @@ def __init__( self.source_type: Literal["public_url", "file"] = source_type self.categories = categories self.pii_types = pii_types - self.max_examples = max_examples if categories is not None: self._validate_enums(categories, VisualLeakBenchCategory, "category") @@ -170,9 +165,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: prompts.extend(pair) - if self.max_examples is not None and len(prompts) >= self.max_examples * 2: - break - if failed_image_count > 0: logger.warning(f"[VisualLeakBench] Skipped {failed_image_count} image(s) due to fetch failures") diff --git a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py index 1ab86cfe7..ca7c79b77 100644 --- a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py @@ -98,7 +98,6 @@ def __init__( *, subset: VLGuardSubset = VLGuardSubset.UNSAFES, categories: list[VLGuardCategory] | None = None, - max_examples: int | None = None, token: str | None = None, ) -> None: """ @@ -108,8 +107,6 @@ def __init__( subset (VLGuardSubset): Which evaluation subset to load. Defaults to UNSAFES. categories (list[VLGuardCategory] | None): List of VLGuard categories to filter by. If None, all categories are included. - max_examples (int | None): Maximum number of multimodal examples to fetch. Each example - produces 2 prompts (text + image). If None, fetches all examples. token (str | None): HuggingFace authentication token for accessing the gated dataset. If None, uses the default token from the environment or HuggingFace CLI login. 
@@ -118,7 +115,6 @@ def __init__( """ self.subset = subset self.categories = categories - self.max_examples = max_examples self.token = token self.source = f"https://huggingface.co/datasets/{_HF_REPO_ID}" @@ -231,10 +227,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: prompts.append(text_prompt) prompts.append(image_prompt) - # len(prompts) is divided by two since each example produces one image and one text prompt. - if self.max_examples is not None and len(prompts) >= self.max_examples * 2: - break - logger.info(f"Successfully loaded {len(prompts)} prompts from VLGuard dataset ({self.subset.value})") return SeedDataset(seeds=prompts, dataset_name=self.dataset_name) diff --git a/tests/unit/datasets/test_comic_jailbreak_dataset.py b/tests/unit/datasets/test_comic_jailbreak_dataset.py index f818996ac..ad9c3e96e 100644 --- a/tests/unit/datasets/test_comic_jailbreak_dataset.py +++ b/tests/unit/datasets/test_comic_jailbreak_dataset.py @@ -116,21 +116,6 @@ async def test_fetch_dataset_multiple_templates(self): # 3 templates with text × 1 goal = 3 groups × 3 seeds = 9 assert len(dataset.seeds) == 9 - async def test_fetch_dataset_max_examples(self): - """max_examples limits the number of pairs produced.""" - mock_data = [_make_example(), _make_example(Goal="Another harmful goal")] - loader = _ComicJailbreakDataset(templates=["article", "speech", "message"], max_examples=2) - - with ( - patch.object(loader, "_fetch_from_url", return_value=mock_data), - patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), - patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), - ): - dataset = await loader.fetch_dataset_async(cache=False) - - # max_examples=2 → at most 2 groups × 3 seeds = 6 - assert len(dataset.seeds) <= 6 - async def test_fetch_dataset_metadata(self): """Metadata contains goal, template, and behavior.""" mock_data = [_make_example()] diff 
--git a/tests/unit/datasets/test_msts_dataset.py b/tests/unit/datasets/test_msts_dataset.py index 894b1796e..704e0b9b5 100644 --- a/tests/unit/datasets/test_msts_dataset.py +++ b/tests/unit/datasets/test_msts_dataset.py @@ -234,42 +234,6 @@ async def test_failed_image_is_skipped(english_rows): assert len(dataset.seeds) == 6 -async def test_max_examples_limits_rows(english_rows): - loader = _MSTSDataset(max_examples=1) - - with ( - patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=english_rows)), - patch.object( - loader, - "_fetch_and_save_image_async", - new=AsyncMock(return_value="/tmp/msts.jpg"), - ), - ): - dataset = await loader.fetch_dataset_async() - - # Only 1 row processed → 2 prompts - assert len(dataset.seeds) == 2 - - -async def test_max_examples_does_not_count_failed_rows(english_rows): - loader = _MSTSDataset(max_examples=2) - - # First row fails, next three succeed. max_examples=2 should still return 2 pairs - # (4 prompts) by skipping the failed row, not 1 pair (2 prompts). 
- with ( - patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=english_rows)), - patch.object( - loader, - "_fetch_and_save_image_async", - new=AsyncMock(side_effect=[Exception("network error"), "/tmp/a.jpg", "/tmp/b.jpg", "/tmp/c.jpg"]), - ), - ): - dataset = await loader.fetch_dataset_async() - - # 2 successful pairs × 2 prompts each = 4 prompts (failed row did not consume quota) - assert len(dataset.seeds) == 4 - - async def test_metadata_includes_msts_fields(english_rows): loader = _MSTSDataset() diff --git a/tests/unit/datasets/test_promptintel_dataset.py b/tests/unit/datasets/test_promptintel_dataset.py index 8c43bd56e..1c0d2ebd7 100644 --- a/tests/unit/datasets/test_promptintel_dataset.py +++ b/tests/unit/datasets/test_promptintel_dataset.py @@ -294,16 +294,6 @@ async def test_pagination_fetches_all_pages(self, api_key): assert len(dataset.seeds) == 2 # 1 prompt from page1 + 1 from page2 = 2 SeedPrompts - async def test_max_prompts_limits_results(self, api_key, mock_promptintel_response): - loader = _PromptIntelDataset(api_key=api_key, max_prompts=1) - mock_resp = _make_mock_response(json_data=mock_promptintel_response) - - with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset_async() - - # max_prompts=1 should limit to 1 SeedPrompt - assert len(dataset.seeds) == 1 - class TestPromptIntelDatasetAPIErrors: """Test error handling for API failures.""" diff --git a/tests/unit/datasets/test_visual_leak_bench_dataset.py b/tests/unit/datasets/test_visual_leak_bench_dataset.py index a425e6b3d..f7bd59737 100644 --- a/tests/unit/datasets/test_visual_leak_bench_dataset.py +++ b/tests/unit/datasets/test_visual_leak_bench_dataset.py @@ -49,7 +49,6 @@ def test_init_defaults(self): dataset = _VisualLeakBenchDataset() assert dataset.categories is None assert dataset.pii_types is None - assert dataset.max_examples is None def test_init_with_categories(self): """Test initialization with category filtering.""" @@ 
-83,11 +82,6 @@ def test_init_rejects_raw_string_matching_enum_value_for_pii_types(self): with pytest.raises(ValueError, match="Expected VisualLeakBenchPIIType"): _VisualLeakBenchDataset(pii_types=["Email"]) - def test_init_with_max_examples(self): - """Test initialization with max_examples.""" - dataset = _VisualLeakBenchDataset(max_examples=10) - assert dataset.max_examples == 10 - async def test_fetch_dataset_ocr_creates_pair(self): """Test that OCR Injection example creates an image+text pair.""" mock_data = [_make_ocr_example()] @@ -219,24 +213,6 @@ async def test_pii_type_filter_does_not_affect_ocr(self): categories = [seed.harm_categories for seed in dataset.seeds] assert any("ocr_injection" in cats for cats in categories) - async def test_max_examples_limits_output(self): - """Test that max_examples limits the number of examples returned.""" - mock_data = [ - _make_ocr_example(filename="ocr_v2_0000.png"), - _make_ocr_example(filename="ocr_v2_0001.png"), - _make_ocr_example(filename="ocr_v2_0002.png"), - ] - loader = _VisualLeakBenchDataset(max_examples=2) - - with ( - patch.object(loader, "_fetch_from_url", return_value=mock_data), - patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), - ): - dataset = await loader.fetch_dataset_async(cache=False) - - # max_examples=2 → at most 4 prompts (2 pairs) - assert len(dataset.seeds) <= 4 - async def test_all_images_fail_produces_empty_dataset(self): """Test that when all image downloads fail, no prompts are produced and SeedDataset raises.""" mock_data = [_make_ocr_example()] diff --git a/tests/unit/datasets/test_vlguard_dataset.py b/tests/unit/datasets/test_vlguard_dataset.py index 7cc8c3506..c13a0ba0d 100644 --- a/tests/unit/datasets/test_vlguard_dataset.py +++ b/tests/unit/datasets/test_vlguard_dataset.py @@ -181,25 +181,6 @@ async def test_category_filtering(self, mock_vlguard_metadata, tmp_path): text_prompts = [p for p in dataset.seeds if p.data_type == "text"] assert 
text_prompts[0].harm_categories == ["privacy"] - async def test_max_examples(self, mock_vlguard_metadata, tmp_path): - """Test that max_examples limits the number of returned examples.""" - image_dir = tmp_path / "test" - image_dir.mkdir() - (image_dir / "unsafe_001.jpg").write_bytes(b"fake image") - (image_dir / "unsafe_002.jpg").write_bytes(b"fake image") - - loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES, max_examples=1) - - with patch.object( - loader, - "_download_dataset_files_async", - new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), - ): - dataset = await loader.fetch_dataset_async() - - # max_examples=1 → 1 example × 2 prompts = 2 prompts - assert len(dataset.seeds) == 2 - async def test_prompt_group_id_links_text_and_image(self, mock_vlguard_metadata, tmp_path): """Test that text and image prompts share the same prompt_group_id.""" image_dir = tmp_path / "test"