diff --git a/src/virtual_stain_flow/datasets/ds_engine/crop_generator.py b/src/virtual_stain_flow/datasets/ds_engine/crop_generator.py index ad5cad1..bec4f5f 100644 --- a/src/virtual_stain_flow/datasets/ds_engine/crop_generator.py +++ b/src/virtual_stain_flow/datasets/ds_engine/crop_generator.py @@ -3,98 +3,19 @@ Utilities for generating crop coordinates from BaseImageDataset objects. Designed for easy creation of CropImageDataset instances. +Made Facade to account for increased complexity and future expansion. """ -from typing import Dict, List, Tuple, Any, Protocol - -from ..base_dataset import BaseImageDataset -from .ds_utils import ( - _get_active_channels, - _validate_same_dimensions_across_channels -) - -CropSpec = Tuple[Tuple[int, int], int, int] -CropMap = Dict[int, List[CropSpec]] - - -class CropGenerator(Protocol): - """ - Protocol for crop generator functions. - """ - def __call__( - self, - dataset: BaseImageDataset, - **kwargs: Any - ) -> CropMap: - pass - - -def _compute_center_crop( - image_width: int, - image_height: int, - crop_size: int -) -> Tuple[int, int]: - """ - Compute top-left (x, y) coordinates for a center crop. - - :param image_width: Width of the source image. - :param image_height: Height of the source image. - :param crop_size: Size of the square crop (width and height). - :return: Tuple of (x, y) for top-left corner of center crop. - :raises ValueError: If crop_size exceeds image dimensions. - """ - if crop_size > image_width or crop_size > image_height: - raise ValueError( - f"crop_size ({crop_size}) exceeds image dimensions " - f"({image_width}x{image_height})." - ) - - x = (image_width - crop_size) // 2 - y = (image_height - crop_size) // 2 - return x, y - - -def generate_center_crops( - dataset: BaseImageDataset, - crop_size: int, -) -> CropMap: - """ - Generate center crop coordinates for each sample in a BaseImageDataset. - - :param dataset: A BaseImageDataset instance (or compatible object with - `file_state.manifest` attribute supporting `get_image_dimensions()`). - :param crop_size: Size of the square crop (same width and height). - :return: Dictionary mapping manifest indices to lists of crop specs. - Format: {manifest_idx: [((x, y), width, height), ...]} - :raises ValueError: If crop_size is non-positive, if no active channels - are configured, or if channel dimensions don't match for any sample. - """ - if crop_size <= 0: - raise ValueError(f"crop_size must be positive, got {crop_size}.") - - active_channels = _get_active_channels(dataset) - if not active_channels: - raise ValueError( - "No active channels configured. Set input_channel_keys and/or " - "target_channel_keys on the dataset before generating crops." - ) - - manifest = dataset.file_state.manifest - crop_specs: Dict[int, List[Tuple[Tuple[int, int], int, int]]] = {} - - for idx in range(len(dataset)): - # Get dimensions for all active channels - dims = manifest.get_image_dimensions(idx, channels=active_channels) - - # Validate all channels have matching dimensions - width, height = _validate_same_dimensions_across_channels( - dims, active_channels, idx - ) - - # Compute center crop coordinates - x, y = _compute_center_crop(width, height, crop_size) - - # Store as crop_specs format: {idx: [((x, y), w, h), ...]} - crop_specs[idx] = [((x, y), crop_size, crop_size)] - - return crop_specs +from .crop_generators.protocol import CropSpec, CropMap, CropGenerator +from .crop_generators.center import generate_center_crops +from .crop_generators.point_centered import generate_point_centered_crops +from .crop_generators.tile import generate_tile_crops + +__all__ = [ + "CropSpec", + "CropMap", + "CropGenerator", + "generate_center_crops", + "generate_point_centered_crops", + "generate_tile_crops", +] diff --git a/src/virtual_stain_flow/datasets/ds_engine/crop_generators/__init__.py b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/__init__.py new file mode 100644 index 0000000..a012d18 --- /dev/null +++ b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/__init__.py @@ -0,0 +1,13 @@ +from .protocol import CropMap, CropSpec, CropGenerator +from .center import generate_center_crops +from .point_centered import generate_point_centered_crops +from .tile import generate_tile_crops + +__all__ = [ + "CropMap", + "CropSpec", + "CropGenerator", + "generate_center_crops", + "generate_point_centered_crops", + "generate_tile_crops" +] diff --git a/src/virtual_stain_flow/datasets/ds_engine/crop_generators/center.py b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/center.py new file mode 100644 index 0000000..7831847 --- /dev/null +++ b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/center.py @@ -0,0 +1,82 @@ +""" +""" + +from typing import Dict, List, Tuple + +from .protocol import CropMap +from ...base_dataset import BaseImageDataset +from ..ds_utils import ( + _get_active_channels, + _validate_same_dimensions_across_channels +) + + +def _compute_center_crop( + image_width: int, + image_height: int, + crop_size: int +) -> Tuple[int, int]: + """ + Compute top-left (x, y) coordinates for a center crop. + + :param image_width: Width of the source image. + :param image_height: Height of the source image. + :param crop_size: Size of the square crop (width and height). + :return: Tuple of (x, y) for top-left corner of center crop. + :raises ValueError: If crop_size exceeds image dimensions. + """ + if crop_size > image_width or crop_size > image_height: + raise ValueError( + f"crop_size ({crop_size}) exceeds image dimensions " + f"({image_width}x{image_height})." + ) + + x = (image_width - crop_size) // 2 + y = (image_height - crop_size) // 2 + return x, y + + +def generate_center_crops( + dataset: BaseImageDataset, + crop_size: int, +) -> CropMap: + """ + Generate center crop coordinates for each sample in a BaseImageDataset. + + :param dataset: A BaseImageDataset instance (or compatible object with + `file_state.manifest` attribute supporting `get_image_dimensions()`). + :param crop_size: Size of the square crop (same width and height). + :return: Dictionary mapping manifest indices to lists of crop specs. + Format: {manifest_idx: [((x, y), width, height), ...]} + :raises ValueError: If crop_size is non-positive, if no active channels + are configured, or if channel dimensions don't match for any sample. + """ + if crop_size <= 0: + raise ValueError(f"crop_size must be positive, got {crop_size}.") + + active_channels = _get_active_channels(dataset) + if not active_channels: + raise ValueError( + "No active channels configured. Set input_channel_keys and/or " + "target_channel_keys on the dataset before generating crops." + ) + + manifest = dataset.file_state.manifest + crop_specs: Dict[int, List[Tuple[Tuple[int, int], int, int]]] = {} + + for idx in range(len(dataset)): + # Get dimensions for all active channels + dims = manifest.get_image_dimensions(idx, channels=active_channels) + + # Validate all channels have matching dimensions + width, height = _validate_same_dimensions_across_channels( + dims, active_channels, idx + ) + + # Compute center crop coordinates + x, y = _compute_center_crop(width, height, crop_size) + + # Store as crop_specs format: {idx: [((x, y), w, h), ...]} + crop_specs[idx] = [((x, y), crop_size, crop_size)] + + return crop_specs diff --git a/src/virtual_stain_flow/datasets/ds_engine/crop_generators/point_centered.py b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/point_centered.py new file mode 100644 index 0000000..2d82f32 --- /dev/null +++ b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/point_centered.py @@ -0,0 +1,103 @@ +""" +""" + +import numpy as np + +from .protocol import CropMap +from ...base_dataset import BaseImageDataset +from ..ds_utils import ( + _get_active_channels, + _validate_same_dimensions_across_channels +) + + +def _compute_point_centered_crops( + image_width: int, + image_height: int, + crop_size: int, + centers: dict[str, np.ndarray] +) -> list[tuple[int, int]]: + """ + Compute the top-left coordinates of crops centered around specified points, + ensuring that the crops fit fully within the image boundaries. + + :param image_width: Width of the image. + :param image_height: Height of the image. + :param crop_size: Size of the crops to generate. + :param centers: Dictionary containing 'X' and 'Y' coordinates of the centers. + :return: List of (x, y) tuples representing the top-left coordinates of the crops. + """ + + crops = [] + + for x, y in zip(centers['X'], centers['Y']): + + if x <= crop_size // 2 or x >= image_width - crop_size // 2 or \ + y <= crop_size // 2 or y >= image_height - crop_size // 2: + continue # Skip points too close to the edge for a full crop + + crop_x = int(x - crop_size // 2) + crop_y = int(y - crop_size // 2) + crops.append((crop_x, crop_y)) + + return crops + + +def generate_point_centered_crops( + dataset: BaseImageDataset, + crop_size: int | None = None, + mapping: list[dict[str, np.ndarray]] | None = None, +) -> CropMap: + """ + Generate crop specifications centered around specified points + for each image in the dataset. + + :param dataset: BaseImageDataset containing the images and metadata. + :param crop_size: Size of the crops to generate. Made optional + for better error handling when called from the CropImageDataset + .from_base_dataset class method. + :param mapping: List of dictionaries containing the centers for each image. + Each dictionary should have keys 'X' and 'Y' with corresponding numpy + arrays of coordinates. + Made optional for the same reason as crop_size. + :return: CropMap containing the generated crop specifications. + """ + + if crop_size is None: + raise ValueError( + "crop_size must be provided for point-centered crop generation." + ) + if mapping is None: + raise ValueError( + "mapping must be provided for point-centered crop generation." + ) + + if crop_size <= 0: + raise ValueError(f"crop_size must be positive, got {crop_size}.") + + active_channels = _get_active_channels(dataset) + if not active_channels: + raise ValueError( + "No active channels configured. Set input_channel_keys and/or " + "target_channel_keys on the dataset before generating crops." + ) + + manifest = dataset.file_state.manifest + crop_specs: dict[int, list[tuple[tuple[int, int], int, int]]] = {} + + for idx in range(len(dataset)): + + dims = manifest.get_image_dimensions(idx, channels=active_channels) + + width, height = _validate_same_dimensions_across_channels( + dims, active_channels, idx + ) + + centers = mapping[idx] + + crop_lists = _compute_point_centered_crops(width, height, crop_size, centers) + crop_specs[idx] = [ + (crop_coords, crop_size, crop_size) for crop_coords in crop_lists + ] + + return crop_specs diff --git a/src/virtual_stain_flow/datasets/ds_engine/crop_generators/protocol.py b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/protocol.py new file mode 100644 index 0000000..374dc91 --- /dev/null +++ b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/protocol.py @@ -0,0 +1,25 @@ +""" +protocol.py + +Defines the CropGenerator protocol for crop generator functions. +""" + +from typing import Dict, List, Tuple, Any, Protocol + +from ...base_dataset import BaseImageDataset + + +CropSpec = Tuple[Tuple[int, int], int, int] +CropMap = Dict[int, List[CropSpec]] + + +class CropGenerator(Protocol): + """ + Protocol for crop generator functions. + """ + def __call__( + self, + dataset: BaseImageDataset, + **kwargs: Any + ) -> CropMap: + pass diff --git a/src/virtual_stain_flow/datasets/ds_engine/crop_generators/tile.py b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/tile.py new file mode 100644 index 0000000..0c37e8d --- /dev/null +++ b/src/virtual_stain_flow/datasets/ds_engine/crop_generators/tile.py @@ -0,0 +1,119 @@ +""" +Crop generator module for creating non-overlapping tile crops centered within +images in a BaseImageDataset. Provides a best-effort tiling approach that +maximizes the number of full tiles while centering the grid within the image +boundaries. +""" + +from typing import List + +from .protocol import CropMap, CropSpec +from ...base_dataset import BaseImageDataset +from ..ds_utils import ( + _get_active_channels, + _validate_same_dimensions_across_channels +) + + +def _compute_centered_tile_positions( + image_extent: int, + crop_size: int +) -> List[int]: + """ + Compute 1D tile start positions for non-overlapping tiles centered + within an image extent. + + :param image_extent: Size of image extent along one axis (width or height). + :param crop_size: Size of one tile along the same axis. + :return: List of start positions for each tile. + :raises ValueError: If crop_size exceeds image extent. + """ + n_tiles = image_extent // crop_size + if n_tiles == 0: + raise ValueError( + f"crop_size ({crop_size}) exceeds image dimensions " + f"along one axis ({image_extent})." + ) + + covered_extent = n_tiles * crop_size + margin = image_extent - covered_extent + start = margin // 2 + + return [start + (tile_idx * crop_size) for tile_idx in range(n_tiles)] + + +def _compute_centered_tile_crops( + image_width: int, + image_height: int, + crop_size: int +) -> List[CropSpec]: + """ + Compute non-overlapping square tiles arranged on a centered grid. + + :param image_width: Width of the source image. + :param image_height: Height of the source image. + :param crop_size: Size of square tile crop. + :return: List of crop specs in format ((x, y), width, height). + :raises ValueError: If crop_size exceeds image width or height. + """ + if crop_size > image_width or crop_size > image_height: + raise ValueError( + f"crop_size ({crop_size}) exceeds image dimensions " + f"({image_width}x{image_height})." + ) + + x_positions = _compute_centered_tile_positions(image_width, crop_size) + y_positions = _compute_centered_tile_positions(image_height, crop_size) + + return [ + ((x, y), crop_size, crop_size) + for y in y_positions + for x in x_positions + ] + + +def generate_tile_crops( + dataset: BaseImageDataset, + crop_size: int, +) -> CropMap: + """ + Generate best-effort centered, non-overlapping tiling crops + for each sample in a BaseImageDataset. + + Tiling is "best-effort" in that it uses as many full, non-overlapping + tiles of size `crop_size` as possible along each axis, then centers + the tile grid within the remaining field of view. + + :param dataset: A BaseImageDataset instance (or compatible object with + `file_state.manifest` attribute supporting `get_image_dimensions()`). + :param crop_size: Size of the square crop tile (same width and height). + :return: Dictionary mapping manifest indices to lists of crop specs. + Format: {manifest_idx: [((x, y), width, height), ...]} + :raises ValueError: If crop_size is non-positive, if no active channels + are configured, or if channel dimensions don't match for any sample. + """ + if crop_size <= 0: + raise ValueError(f"crop_size must be positive, got {crop_size}.") + + active_channels = _get_active_channels(dataset) + if not active_channels: + raise ValueError( + "No active channels configured. Set input_channel_keys and/or " + "target_channel_keys on the dataset before generating crops." + ) + + manifest = dataset.file_state.manifest + crop_specs: CropMap = {} + + for idx in range(len(dataset)): + dims = manifest.get_image_dimensions(idx, channels=active_channels) + + width, height = _validate_same_dimensions_across_channels( + dims, active_channels, idx + ) + + crop_specs[idx] = _compute_centered_tile_crops( + width, height, crop_size + ) + + return crop_specs