Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyrit/backend/mappers/attack_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

import logging
import mimetypes
import os
import time
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Optional, cast
from urllib.parse import quote, urlparse

Expand Down Expand Up @@ -178,7 +178,7 @@ def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str]
if value.startswith(("http://", "https://", "data:")):
return value
# Local file path — construct a media endpoint URL
if os.path.isfile(value):
if Path(value).is_file():
return f"/api/media?path={quote(str(value))}"
return value

Expand Down Expand Up @@ -373,7 +373,7 @@ def _build_filename(
source = value
if source.startswith("http"):
source = urlparse(source).path
ext = os.path.splitext(source)[1] # e.g. ".png"
ext = Path(source).suffix # e.g. ".png"

if not ext:
# Fallback: guess from mime type based on data type prefix
Expand Down
32 changes: 16 additions & 16 deletions pyrit/backend/routes/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import logging
import mimetypes
import os
from pathlib import Path

from fastapi import APIRouter, HTTPException, Query
Expand Down Expand Up @@ -61,39 +60,40 @@
}


def _validate_media_path(*, path: str, allowed_root: Path) -> Path:
    """
    Validate and sanitize a user-provided file path against an allowed root directory.

    Uses ``Path.resolve()`` to resolve symlinks and ``..`` components, then
    verifies the canonical path is under the allowed root. This is the standard
    sanitization pattern recognized by static analysis tools (e.g. CodeQL
    ``py/path-injection``).

    Args:
        path: The user-provided file path to validate.
        allowed_root: The canonical (``resolve``-d) allowed root directory.

    Returns:
        The canonical, validated file path.

    Raises:
        HTTPException 403: If the path cannot be resolved or fails any validation check.
    """
    try:
        real_path = Path(path).resolve(strict=False)
    except (OSError, RuntimeError, ValueError) as exc:
        # ``resolve()`` runs on untrusted input and may raise on malformed paths
        # (e.g. ``ValueError`` for an embedded NUL byte, or ``RuntimeError`` for a
        # symlink loop on Python < 3.13). Reject instead of surfacing a 500.
        raise HTTPException(status_code=403, detail="Access denied: path could not be resolved.") from exc

    # ``relative_to`` raises ValueError when ``real_path`` is not located under ``allowed_root``.
    try:
        relative_parts = real_path.relative_to(allowed_root).parts
    except ValueError as exc:
        raise HTTPException(
            status_code=403, detail="Access denied: path is outside the allowed results directory."
        ) from exc

    # Restrict to known media subdirectories (e.g. prompt-memory-entries/).
    # An empty ``relative_parts`` means the path is the root itself, which is never a media file.
    if not relative_parts or relative_parts[0] not in _ALLOWED_SUBDIRECTORIES:
        raise HTTPException(status_code=403, detail="Access denied: path is not in a media subdirectory.")

    # Only allow known media file extensions (compared case-insensitively)
    if real_path.suffix.lower() not in _ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=403, detail="Access denied: file type is not allowed.")

    return real_path
Expand Down Expand Up @@ -125,13 +125,13 @@ async def serve_media_async(
memory = CentralMemory.get_memory_instance()
if not memory.results_path:
raise HTTPException(status_code=500, detail="Memory results_path is not configured.")
allowed_root = os.path.realpath(memory.results_path)
allowed_root = Path(memory.results_path).resolve(strict=False)
except Exception as exc:
raise HTTPException(status_code=500, detail="Memory not initialized; cannot determine results path.") from exc

validated_path = _validate_media_path(path=path, allowed_root=allowed_root)

if not os.path.isfile(validated_path):
if not validated_path.is_file():
raise HTTPException(status_code=404, detail="File not found.")

mime_type, _ = mimetypes.guess_type(validated_path)
Expand Down
7 changes: 3 additions & 4 deletions pyrit/models/storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import annotations

import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -110,7 +109,7 @@ async def path_exists(self, path: Union[Path, str]) -> bool:

"""
path = self._convert_to_path(path)
return os.path.exists(path)
return path.exists()

async def is_file(self, path: Union[Path, str]) -> bool:
"""
Expand All @@ -124,7 +123,7 @@ async def is_file(self, path: Union[Path, str]) -> bool:

"""
path = self._convert_to_path(path)
return os.path.isfile(path)
return path.is_file()

async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None:
"""
Expand All @@ -136,7 +135,7 @@ async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None:
"""
directory_path = self._convert_to_path(path)
if not directory_path.exists():
os.makedirs(directory_path, exist_ok=True)
directory_path.mkdir(parents=True, exist_ok=True)

def _convert_to_path(self, path: Union[Path, str]) -> Path:
"""
Expand Down
31 changes: 17 additions & 14 deletions pyrit/output/conversation/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import logging
import os
from pathlib import Path

from pyrit.models import Message, MessagePiece, Score
from pyrit.output.conversation.base import ConversationPrinterBase
Expand Down Expand Up @@ -224,11 +225,13 @@ def _format_image_content(self, *, image_path: str) -> list[str]:
@staticmethod
def _format_link_path(path: str) -> str:
Comment thread
romanlutz marked this conversation as resolved.
"""Return a markdown-friendly link (POSIX separators, relative if possible)."""
path_obj = Path(path)
try:
relative_path = os.path.relpath(path)
relative_path = str(path_obj.relative_to(Path.cwd()))
except ValueError:
# Different mount/drive than cwd (Windows). Fall back to the absolute path.
relative_path = os.path.abspath(path)
# Path is not under cwd (different drive on Windows, or simply outside cwd).
# Fall back to the absolute path.
relative_path = str(path_obj.resolve())
return relative_path.replace("\\", "/")

def _maybe_blur_image_on_disk(self, *, image_path: str) -> str | None:
Expand All @@ -251,30 +254,30 @@ def _maybe_blur_image_on_disk(self, *, image_path: str) -> str | None:
str | None: The path to the blurred image, or ``None`` on failure.
"""
try:
blurred_path = self._blurred_destination(image_path=image_path)
if os.path.exists(blurred_path):
blurred_path = Path(self._blurred_destination(image_path=image_path))
if blurred_path.exists():
logger.debug(f"Reusing cached blurred image at {blurred_path}")
return blurred_path
return str(blurred_path)

os.makedirs(os.path.dirname(blurred_path) or ".", exist_ok=True)
blurred_path.parent.mkdir(parents=True, exist_ok=True)

from pyrit.output._image_utils import blur_image_bytes

with open(image_path, "rb") as f:
original_bytes = f.read()
blurred_bytes = blur_image_bytes(image_bytes=original_bytes, radius=self._blur_radius)

temp_path = f"{blurred_path}.tmp.{os.getpid()}"
temp_path = blurred_path.parent / f"{blurred_path.name}.tmp.{os.getpid()}"
try:
with open(temp_path, "wb") as f:
f.write(blurred_bytes)
os.replace(temp_path, blurred_path)
except Exception:
if os.path.exists(temp_path):
if temp_path.exists():
with contextlib.suppress(OSError):
os.remove(temp_path)
temp_path.unlink()
raise
return blurred_path
return str(blurred_path)
except Exception as exc:
logger.warning(f"Failed to write blurred image for {image_path}; falling back to a text link. Error: {exc}")
return None
Expand All @@ -289,9 +292,9 @@ def _blurred_destination(self, *, image_path: str) -> str:
Returns:
str: Path to the blurred file (sibling by default, or under ``blurred_dir``).
"""
directory = self._blurred_dir if self._blurred_dir is not None else os.path.dirname(image_path)
stem = os.path.splitext(os.path.basename(image_path))[0]
return os.path.join(directory, f"{stem}_blurred.png")
image_path_obj = Path(image_path)
directory = Path(self._blurred_dir) if self._blurred_dir is not None else image_path_obj.parent
return str(directory / f"{image_path_obj.stem}_blurred.png")

def _format_audio_content(self, *, audio_path: str) -> list[str]:
"""
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/backend/test_media_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_rejects_path_outside_results_directory(self, client: TestClient, _mock_
try:
response = client.get("/api/media", params={"path": outside_path})
assert response.status_code == 403
assert "outside the allowed results directory" in response.json()["detail"]
finally:
os.unlink(outside_path)

Expand All @@ -69,6 +70,34 @@ def test_rejects_path_traversal(self, client: TestClient, _mock_memory: Path) ->
traversal_path = str(_mock_memory / ".." / ".." / "etc" / "passwd")
response = client.get("/api/media", params={"path": traversal_path})
assert response.status_code == 403
assert "outside the allowed results directory" in response.json()["detail"]

def test_rejects_symlink_pointing_outside_results(self, client: TestClient, _mock_memory: Path) -> None:
    """Reject a symlink that sits in an allowed subdirectory but targets a file outside the results dir.

    The route canonicalizes the requested path with ``Path.resolve()``, which follows
    symlinks, so the escape must be refused exactly like a plain ``..`` traversal.
    """
    link_location = _mock_memory / "prompt-memory-entries" / "evil_symlink.png"

    # The symlink target lives outside the allowed directory.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as outside_file:
        outside_file.write(b"\x89PNG\r\n\x1a\n")
        outside_target = outside_file.name

    try:
        try:
            os.symlink(outside_target, link_location)
        except (OSError, NotImplementedError) as exc:
            # Symlink creation may fail on Windows without admin / developer mode.
            pytest.skip(f"Cannot create symlink in this environment: {exc}")

        response = client.get("/api/media", params={"path": str(link_location)})

        assert response.status_code == 403
        # The 403 must come from the symlink-escape check specifically,
        # not from the subdirectory or extension checks.
        assert "outside the allowed results directory" in response.json()["detail"]
    finally:
        os.unlink(outside_target)

def test_returns_404_for_nonexistent_file(self, client: TestClient, _mock_memory: Path) -> None:
"""Non-existent files under allowed subdirectory return 404."""
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/models/test_storage_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,30 @@ async def test_disk_storage_io_path_exists():
storage = DiskStorageIO()
path = "sample.txt"

with patch("os.path.exists", return_value=True) as mock_exists:
with patch("pathlib.Path.exists", return_value=True) as mock_exists:
result = await storage.path_exists(path)
assert result is True
mock_exists.assert_called_once_with(Path(path))
mock_exists.assert_called_once()


async def test_disk_storage_io_is_file():
    """``DiskStorageIO.is_file`` delegates to ``Path.is_file`` on the converted path."""
    storage = DiskStorageIO()
    path = "sample.txt"

    # autospec=True records the bound ``Path`` instance as the first call argument,
    # so the test can still verify *which* path was inspected.
    with patch.object(Path, "is_file", autospec=True, return_value=True) as mock_isfile:
        result = await storage.is_file(path)

    assert result is True
    mock_isfile.assert_called_once_with(Path(path))


async def test_disk_storage_io_create_directory_if_not_exists():
    """A missing directory is created with ``Path.mkdir(parents=True, exist_ok=True)``."""
    storage = DiskStorageIO()
    directory_path = "sample_dir"

    # autospec=True records the bound ``Path`` instance as the first call argument,
    # so both assertions verify the directory being checked and created.
    with patch.object(Path, "exists", autospec=True, return_value=False) as mock_exists:
        with patch.object(Path, "mkdir", autospec=True) as mock_mkdir:
            await storage.create_directory_if_not_exists(directory_path)

    mock_exists.assert_called_once_with(Path(directory_path))
    mock_mkdir.assert_called_once_with(Path(directory_path), parents=True, exist_ok=True)


async def test_azure_blob_storage_io_read_file(azure_blob_storage_io):
Expand Down
Loading
Loading