Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 14 additions & 93 deletions src/virtual_stain_flow/datasets/ds_engine/crop_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,98 +3,19 @@

Utilities for generating crop coordinates from BaseImageDataset objects.
Designed for easy creation of CropImageDataset instances.
Made Facade to account for increased complexity and future expansion.
"""

from typing import Dict, List, Tuple, Any, Protocol

from ..base_dataset import BaseImageDataset
from .ds_utils import (
_get_active_channels,
_validate_same_dimensions_across_channels
)

CropSpec = Tuple[Tuple[int, int], int, int]
CropMap = Dict[int, List[CropSpec]]


class CropGenerator(Protocol):
"""
Protocol for crop generator functions.
"""
def __call__(
self,
dataset: BaseImageDataset,
**kwargs: Any
) -> CropMap:
pass


def _compute_center_crop(
image_width: int,
image_height: int,
crop_size: int
) -> Tuple[int, int]:
"""
Compute top-left (x, y) coordinates for a center crop.

:param image_width: Width of the source image.
:param image_height: Height of the source image.
:param crop_size: Size of the square crop (width and height).
:return: Tuple of (x, y) for top-left corner of center crop.
:raises ValueError: If crop_size exceeds image dimensions.
"""
if crop_size > image_width or crop_size > image_height:
raise ValueError(
f"crop_size ({crop_size}) exceeds image dimensions "
f"({image_width}x{image_height})."
)

x = (image_width - crop_size) // 2
y = (image_height - crop_size) // 2
return x, y


def generate_center_crops(
dataset: BaseImageDataset,
crop_size: int,
) -> CropMap:
"""
Generate center crop coordinates for each sample in a BaseImageDataset.

:param dataset: A BaseImageDataset instance (or compatible object with
`file_state.manifest` attribute supporting `get_image_dimensions()`).
:param crop_size: Size of the square crop (same width and height).
:return: Dictionary mapping manifest indices to lists of crop specs.
Format: {manifest_idx: [((x, y), width, height), ...]}
:raises ValueError: If crop_size is non-positive, if no active channels
are configured, or if channel dimensions don't match for any sample.
"""
if crop_size <= 0:
raise ValueError(f"crop_size must be positive, got {crop_size}.")

active_channels = _get_active_channels(dataset)
if not active_channels:
raise ValueError(
"No active channels configured. Set input_channel_keys and/or "
"target_channel_keys on the dataset before generating crops."
)

manifest = dataset.file_state.manifest
crop_specs: Dict[int, List[Tuple[Tuple[int, int], int, int]]] = {}

for idx in range(len(dataset)):
# Get dimensions for all active channels
dims = manifest.get_image_dimensions(idx, channels=active_channels)

# Validate all channels have matching dimensions
width, height = _validate_same_dimensions_across_channels(
dims, active_channels, idx
)

# Compute center crop coordinates
x, y = _compute_center_crop(width, height, crop_size)

# Store as crop_specs format: {idx: [((x, y), w, h), ...]}
crop_specs[idx] = [((x, y), crop_size, crop_size)]

return crop_specs
from .crop_generators.protocol import CropSpec, CropMap, CropGenerator
from .crop_generators.center import generate_center_crops
from .crop_generators.point_centered import generate_point_centered_crops
from .crop_generators.tile import generate_tile_crops

__all__ = [
"CropSpec",
"CropMap",
"CropGenerator",
"generate_center_crops",
"generate_point_centered_crops",
"generate_tile_crops",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .protocol import CropMap, CropSpec, CropGenerator
from .center import generate_center_crops
from .point_centered import generate_point_centered_crops
from .tile import generate_tile_crops

__all__ = [
"CropMap",
"CropSpec",
"CropGenerator",
"generate_center_crops",
"generate_point_centered_crops",
"generate_tile_crops"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
"""

from typing import Dict, List, Tuple

from .protocol import CropMap
from ...base_dataset import BaseImageDataset
from ..ds_utils import (
_get_active_channels,
_validate_same_dimensions_across_channels
)


def _compute_center_crop(
image_width: int,
image_height: int,
crop_size: int
) -> Tuple[int, int]:
"""
Compute top-left (x, y) coordinates for a center crop.

:param image_width: Width of the source image.
:param image_height: Height of the source image.
:param crop_size: Size of the square crop (width and height).
:return: Tuple of (x, y) for top-left corner of center crop.
:raises ValueError: If crop_size exceeds image dimensions.
"""
if crop_size > image_width or crop_size > image_height:
raise ValueError(
f"crop_size ({crop_size}) exceeds image dimensions "
f"({image_width}x{image_height})."
)

x = (image_width - crop_size) // 2
y = (image_height - crop_size) // 2
return x, y


def generate_center_crops(
dataset: BaseImageDataset,
crop_size: int,
) -> CropMap:
"""
Generate center crop coordinates for each sample in a BaseImageDataset.

:param dataset: A BaseImageDataset instance (or compatible object with
`file_state.manifest` attribute supporting `get_image_dimensions()`).
:param crop_size: Size of the square crop (same width and height).
:return: Dictionary mapping manifest indices to lists of crop specs.
Format: {manifest_idx: [((x, y), width, height), ...]}
:raises ValueError: If crop_size is non-positive, if no active channels
are configured, or if channel dimensions don't match for any sample.
"""
if crop_size <= 0:
raise ValueError(f"crop_size must be positive, got {crop_size}.")

active_channels = _get_active_channels(dataset)
if not active_channels:
raise ValueError(
"No active channels configured. Set input_channel_keys and/or "
"target_channel_keys on the dataset before generating crops."
)

manifest = dataset.file_state.manifest
crop_specs: Dict[int, List[Tuple[Tuple[int, int], int, int]]] = {}

for idx in range(len(dataset)):
# Get dimensions for all active channels
dims = manifest.get_image_dimensions(idx, channels=active_channels)

# Validate all channels have matching dimensions
width, height = _validate_same_dimensions_across_channels(
dims, active_channels, idx
)

# Compute center crop coordinates
x, y = _compute_center_crop(width, height, crop_size)

# Store as crop_specs format: {idx: [((x, y), w, h), ...]}
crop_specs[idx] = [((x, y), crop_size, crop_size)]

return crop_specs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
"""

import numpy as np

from .protocol import CropMap
from ...base_dataset import BaseImageDataset
from ..ds_utils import (
_get_active_channels,
_validate_same_dimensions_across_channels
)


def _compute_point_centered_crops(
image_width: int,
image_height: int,
crop_size: int,
centers: dict[str, np.ndarray]
) -> list[tuple[int, int]]:
"""
Compute the top-left coordinates of crops centered around specified points,
ensuring that the crops fit fully within the image boundaries.

:param image_width: Width of the image.
:param image_height: Height of the image.
:param crop_size: Size of the crops to generate.
:param centers: Dictionary containing 'X' and 'Y' coordinates of the centers.
:return: List of (x, y) tuples representing the top-left coordinates of the crops.
"""

crops = []

for x, y in zip(centers['X'], centers['Y']):

if x <= crop_size // 2 or x >= image_width - crop_size // 2 or \
y <= crop_size // 2 or y >= image_height - crop_size // 2:
continue # Skip points too close to the edge for a full crop

crop_x = int(x - crop_size // 2)
crop_y = int(y - crop_size // 2)
crops.append((crop_x, crop_y))

return crops


def generate_point_centered_crops(
dataset: BaseImageDataset,
crop_size: int | None = None,
mapping: list[dict[str, np.ndarray]] | None = None,
) -> CropMap:
"""
Generate crop specifications centered around specified points
for each image in the dataset.

:param dataset: BaseImageDataset containing the images and metadata.
:param crop_size: Size of the crops to generate. Made optional
for better error handling when called from the CropImageDataset
.from_base_dataset class method.
:param mapping: List of dictionaries containing the centers for each image.
Each dictionary should have keys 'X' and 'Y' with corresponding numpy
arrays of coordinates.
Made optional for the same reason as crop_size.
:return: CropMap containing the generated crop specifications.
"""

if crop_size is None:
raise ValueError(
"crop_size must be provided for point-centered crop generation."
)
if mapping is None:
raise ValueError(
"mapping must be provided for point-centered crop generation."
)

if crop_size <= 0:
raise ValueError(f"crop_size must be positive, got {crop_size}.")

active_channels = _get_active_channels(dataset)
if not active_channels:
raise ValueError(
"No active channels configured. Set input_channel_keys and/or "
"target_channel_keys on the dataset before generating crops."
)

manifest = dataset.file_state.manifest
crop_specs: dict[int, list[tuple[tuple[int, int], int, int]]] = {}

for idx in range(len(dataset)):

dims = manifest.get_image_dimensions(idx, channels=active_channels)

width, height = _validate_same_dimensions_across_channels(
dims, active_channels, idx
)

centers = mapping[idx]

crop_lists = _compute_point_centered_crops(width, height, crop_size, centers)
crop_specs[idx] = [
(crop_coords, crop_size, crop_size) for crop_coords in crop_lists
]

return crop_specs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
protocol.py

Defines the CropGenerator protocol for crop generator functions.
"""

from typing import Dict, List, Tuple, Any, Protocol

from ...base_dataset import BaseImageDataset


CropSpec = Tuple[Tuple[int, int], int, int]
CropMap = Dict[int, List[CropSpec]]


class CropGenerator(Protocol):
"""
Protocol for crop generator functions.
"""
def __call__(
self,
dataset: BaseImageDataset,
**kwargs: Any
) -> CropMap:
pass
Loading
Loading