Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def __init__(
),
source_type: Literal["public_url", "file"] = "public_url",
templates: list[str] | None = None,
max_examples: int | None = None,
) -> None:
"""
Initialize the ComicJailbreak dataset loader.
Expand All @@ -115,16 +114,13 @@ def __init__(
at a pinned commit.
source_type: The type of source ('public_url' or 'file').
templates: List of template names to include. If None, all 5 templates are used.
max_examples: Maximum number of goal×template pairs to produce. If None, all
combinations are returned.

Raises:
ValueError: If any template name is invalid.
"""
self.source = source
self.source_type: Literal["public_url", "file"] = source_type
self.templates = templates or list(self.TEMPLATE_NAMES)
self.max_examples = max_examples

invalid = set(self.templates) - set(self.TEMPLATE_NAMES)
if invalid:
Expand Down Expand Up @@ -170,7 +166,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
template_paths[template_name] = await self._fetch_template_async(template_name)

seeds: list[Seed] = []
pair_count = 0

for row_idx, example in enumerate(examples):
missing_keys = required_keys - example.keys()
Expand Down Expand Up @@ -208,15 +203,8 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
behavior=example.get("Behavior", ""),
)
seeds.extend(pair)
pair_count += 1

if self.max_examples is not None and pair_count >= self.max_examples:
break

if self.max_examples is not None and pair_count >= self.max_examples:
break

logger.info(f"Successfully loaded {len(seeds)} seeds ({pair_count} groups) from ComicJailbreak dataset")
logger.info(f"Successfully loaded {len(seeds)} seeds from ComicJailbreak dataset")
return SeedDataset(seeds=seeds, dataset_name=self.dataset_name)

def _build_seed_group(
Expand Down
19 changes: 0 additions & 19 deletions pyrit/datasets/seed_datasets/remote/msts_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def __init__(
*,
languages: list[str] | None = None,
text_modifiers: list[str] | None = None,
max_examples: int | None = None,
token: str | None = None,
) -> None:
"""
Expand All @@ -146,19 +145,13 @@ def __init__(
language. Defaults to ``["en"]``.
text_modifiers (list[str] | None): Subset of {"assistance", "intention"} to include.
Defaults to both.
max_examples (int | None): Maximum number of test pairs to successfully load
across all languages and modifiers combined. Rows whose image fetch fails
are skipped and do NOT count toward this limit, so a request for N examples
returns up to N pairs as long as enough source rows succeed. Each loaded
pair produces 2 SeedPrompts (image + text). Defaults to None (no limit).
token (str | None): Optional HuggingFace authentication token.

Raises:
ValueError: If any language code is unsupported, or any text modifier is invalid.
"""
self.languages = self._resolve_languages(languages)
self.text_modifiers = self._resolve_text_modifiers(text_modifiers)
self.max_examples = max_examples
self.token = token
self.source = f"https://huggingface.co/datasets/{_HF_REPO_ID}"

Expand All @@ -175,10 +168,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
and both have ``sequence=0``, so they are delivered together as a single
multimodal user message (image + text) rather than as two separate turns.

When ``max_examples`` is set, only successfully loaded pairs count toward the
limit; rows whose image fetch fails are logged and skipped without consuming
the quota.

Args:
cache (bool): Whether to cache the fetched dataset and images. Defaults to True.

Expand All @@ -189,7 +178,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:

prompts: list[SeedPrompt] = []
failed_image_count = 0
successful_pairs = 0

for language in self.languages:
split_name = _LANGUAGE_TO_SPLIT[language]
Expand All @@ -201,9 +189,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
)

for row in split_data:
if self.max_examples is not None and successful_pairs >= self.max_examples:
break

if row.get("prompt_type") not in self.text_modifiers:
continue

Expand All @@ -218,10 +203,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
continue

prompts.extend(pair)
successful_pairs += 1

if self.max_examples is not None and successful_pairs >= self.max_examples:
break

if failed_image_count > 0:
logger.warning(f"[MSTS] Skipped {failed_image_count} example(s) due to image fetch failures")
Expand Down
12 changes: 0 additions & 12 deletions pyrit/datasets/seed_datasets/remote/promptintel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def __init__(
severity: Optional[PromptIntelSeverity] = None,
categories: Optional[list[PromptIntelCategory]] = None,
search: Optional[str] = None,
max_prompts: Optional[int] = None,
) -> None:
"""
Initialize the PromptIntel dataset loader.
Expand All @@ -87,7 +86,6 @@ def __init__(
When multiple categories are specified, separate API requests are made for each
category and results are merged with deduplication.
search: Search term to filter prompts by title and content. Defaults to None.
max_prompts: Maximum number of prompts to fetch. Defaults to None (all available).

Raises:
ValueError: If an invalid severity or category is provided.
Expand All @@ -103,7 +101,6 @@ def __init__(
self._severity = severity
self._categories = categories
self._search = search
self._max_prompts = max_prompts
self.source = "https://promptintel.novahunting.ai"

@property
Expand Down Expand Up @@ -177,21 +174,12 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]:
seen_ids.add(record_id)
all_prompts.append(record)

# Check if we've reached the max_prompts limit
if self._max_prompts and len(all_prompts) >= self._max_prompts:
all_prompts = all_prompts[: self._max_prompts]
break

# Check if there are more pages
total_pages = pagination.get("pages", 1)
if page >= total_pages:
break
page += 1

# Also break the outer loop if max_prompts reached
if self._max_prompts and len(all_prompts) >= self._max_prompts:
break

return all_prompts

def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(
source_type: Literal["public_url", "file"] = "public_url",
categories: Optional[list[VisualLeakBenchCategory]] = None,
pii_types: Optional[list[VisualLeakBenchPIIType]] = None,
max_examples: Optional[int] = None,
) -> None:
"""
Initialize the VisualLeakBench dataset loader.
Expand All @@ -98,9 +97,6 @@ def __init__(
VisualLeakBenchCategory.PII_LEAKAGE.
pii_types: List of PII types to include (only relevant for PII_LEAKAGE category).
If None, all PII types are included.
max_examples: Maximum number of examples to fetch. Each example produces 2 prompts
(image + text). If None, fetches all examples. Useful for testing or quick
validations.

Raises:
ValueError: If any of the specified categories or pii_types are invalid.
Expand All @@ -109,7 +105,6 @@ def __init__(
self.source_type: Literal["public_url", "file"] = source_type
self.categories = categories
self.pii_types = pii_types
self.max_examples = max_examples

if categories is not None:
self._validate_enums(categories, VisualLeakBenchCategory, "category")
Expand Down Expand Up @@ -170,9 +165,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:

prompts.extend(pair)

if self.max_examples is not None and len(prompts) >= self.max_examples * 2:
break

if failed_image_count > 0:
logger.warning(f"[VisualLeakBench] Skipped {failed_image_count} image(s) due to fetch failures")

Expand Down
8 changes: 0 additions & 8 deletions pyrit/datasets/seed_datasets/remote/vlguard_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def __init__(
*,
subset: VLGuardSubset = VLGuardSubset.UNSAFES,
categories: list[VLGuardCategory] | None = None,
max_examples: int | None = None,
token: str | None = None,
) -> None:
"""
Expand All @@ -108,8 +107,6 @@ def __init__(
subset (VLGuardSubset): Which evaluation subset to load. Defaults to UNSAFES.
categories (list[VLGuardCategory] | None): List of VLGuard categories to filter by.
If None, all categories are included.
max_examples (int | None): Maximum number of multimodal examples to fetch. Each example
produces 2 prompts (text + image). If None, fetches all examples.
token (str | None): HuggingFace authentication token for accessing the gated dataset.
If None, uses the default token from the environment or HuggingFace CLI login.

Expand All @@ -118,7 +115,6 @@ def __init__(
"""
self.subset = subset
self.categories = categories
self.max_examples = max_examples
self.token = token
self.source = f"https://huggingface.co/datasets/{_HF_REPO_ID}"

Expand Down Expand Up @@ -231,10 +227,6 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
prompts.append(text_prompt)
prompts.append(image_prompt)

# len(prompts) is divided by two since each example produces one image and one text prompt.
if self.max_examples is not None and len(prompts) >= self.max_examples * 2:
break

logger.info(f"Successfully loaded {len(prompts)} prompts from VLGuard dataset ({self.subset.value})")

return SeedDataset(seeds=prompts, dataset_name=self.dataset_name)
Expand Down
15 changes: 0 additions & 15 deletions tests/unit/datasets/test_comic_jailbreak_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,6 @@ async def test_fetch_dataset_multiple_templates(self):
# 3 templates with text × 1 goal = 3 groups × 3 seeds = 9
assert len(dataset.seeds) == 9

async def test_fetch_dataset_max_examples(self):
"""max_examples limits the number of pairs produced."""
mock_data = [_make_example(), _make_example(Goal="Another harmful goal")]
loader = _ComicJailbreakDataset(templates=["article", "speech", "message"], max_examples=2)

with (
patch.object(loader, "_fetch_from_url", return_value=mock_data),
patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"),
patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"),
):
dataset = await loader.fetch_dataset_async(cache=False)

# max_examples=2 → at most 2 groups × 3 seeds = 6
assert len(dataset.seeds) <= 6

async def test_fetch_dataset_metadata(self):
"""Metadata contains goal, template, and behavior."""
mock_data = [_make_example()]
Expand Down
36 changes: 0 additions & 36 deletions tests/unit/datasets/test_msts_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,42 +234,6 @@ async def test_failed_image_is_skipped(english_rows):
assert len(dataset.seeds) == 6


async def test_max_examples_limits_rows(english_rows):
loader = _MSTSDataset(max_examples=1)

with (
patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=english_rows)),
patch.object(
loader,
"_fetch_and_save_image_async",
new=AsyncMock(return_value="/tmp/msts.jpg"),
),
):
dataset = await loader.fetch_dataset_async()

# Only 1 row processed → 2 prompts
assert len(dataset.seeds) == 2


async def test_max_examples_does_not_count_failed_rows(english_rows):
loader = _MSTSDataset(max_examples=2)

# First row fails, next three succeed. max_examples=2 should still return 2 pairs
# (4 prompts) by skipping the failed row, not 1 pair (2 prompts).
with (
patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=english_rows)),
patch.object(
loader,
"_fetch_and_save_image_async",
new=AsyncMock(side_effect=[Exception("network error"), "/tmp/a.jpg", "/tmp/b.jpg", "/tmp/c.jpg"]),
),
):
dataset = await loader.fetch_dataset_async()

# 2 successful pairs × 2 prompts each = 4 prompts (failed row did not consume quota)
assert len(dataset.seeds) == 4


async def test_metadata_includes_msts_fields(english_rows):
loader = _MSTSDataset()

Expand Down
10 changes: 0 additions & 10 deletions tests/unit/datasets/test_promptintel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,6 @@ async def test_pagination_fetches_all_pages(self, api_key):

assert len(dataset.seeds) == 2 # 1 prompt from page1 + 1 from page2 = 2 SeedPrompts

async def test_max_prompts_limits_results(self, api_key, mock_promptintel_response):
loader = _PromptIntelDataset(api_key=api_key, max_prompts=1)
mock_resp = _make_mock_response(json_data=mock_promptintel_response)

with patch("requests.get", return_value=mock_resp):
dataset = await loader.fetch_dataset_async()

# max_prompts=1 should limit to 1 SeedPrompt
assert len(dataset.seeds) == 1


class TestPromptIntelDatasetAPIErrors:
"""Test error handling for API failures."""
Expand Down
24 changes: 0 additions & 24 deletions tests/unit/datasets/test_visual_leak_bench_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def test_init_defaults(self):
dataset = _VisualLeakBenchDataset()
assert dataset.categories is None
assert dataset.pii_types is None
assert dataset.max_examples is None

def test_init_with_categories(self):
"""Test initialization with category filtering."""
Expand Down Expand Up @@ -83,11 +82,6 @@ def test_init_rejects_raw_string_matching_enum_value_for_pii_types(self):
with pytest.raises(ValueError, match="Expected VisualLeakBenchPIIType"):
_VisualLeakBenchDataset(pii_types=["Email"])

def test_init_with_max_examples(self):
"""Test initialization with max_examples."""
dataset = _VisualLeakBenchDataset(max_examples=10)
assert dataset.max_examples == 10

async def test_fetch_dataset_ocr_creates_pair(self):
"""Test that OCR Injection example creates an image+text pair."""
mock_data = [_make_ocr_example()]
Expand Down Expand Up @@ -219,24 +213,6 @@ async def test_pii_type_filter_does_not_affect_ocr(self):
categories = [seed.harm_categories for seed in dataset.seeds]
assert any("ocr_injection" in cats for cats in categories)

async def test_max_examples_limits_output(self):
"""Test that max_examples limits the number of examples returned."""
mock_data = [
_make_ocr_example(filename="ocr_v2_0000.png"),
_make_ocr_example(filename="ocr_v2_0001.png"),
_make_ocr_example(filename="ocr_v2_0002.png"),
]
loader = _VisualLeakBenchDataset(max_examples=2)

with (
patch.object(loader, "_fetch_from_url", return_value=mock_data),
patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"),
):
dataset = await loader.fetch_dataset_async(cache=False)

# max_examples=2 → at most 4 prompts (2 pairs)
assert len(dataset.seeds) <= 4

async def test_all_images_fail_produces_empty_dataset(self):
"""Test that when all image downloads fail, no prompts are produced and SeedDataset raises."""
mock_data = [_make_ocr_example()]
Expand Down
19 changes: 0 additions & 19 deletions tests/unit/datasets/test_vlguard_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,25 +181,6 @@ async def test_category_filtering(self, mock_vlguard_metadata, tmp_path):
text_prompts = [p for p in dataset.seeds if p.data_type == "text"]
assert text_prompts[0].harm_categories == ["privacy"]

async def test_max_examples(self, mock_vlguard_metadata, tmp_path):
"""Test that max_examples limits the number of returned examples."""
image_dir = tmp_path / "test"
image_dir.mkdir()
(image_dir / "unsafe_001.jpg").write_bytes(b"fake image")
(image_dir / "unsafe_002.jpg").write_bytes(b"fake image")

loader = _VLGuardDataset(subset=VLGuardSubset.UNSAFES, max_examples=1)

with patch.object(
loader,
"_download_dataset_files_async",
new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)),
):
dataset = await loader.fetch_dataset_async()

# max_examples=1 → 1 example × 2 prompts = 2 prompts
assert len(dataset.seeds) == 2

async def test_prompt_group_id_links_text_and_image(self, mock_vlguard_metadata, tmp_path):
"""Test that text and image prompts share the same prompt_group_id."""
image_dir = tmp_path / "test"
Expand Down
Loading