diff --git a/src/virtual_stain_flow/evaluation/evaluation_utils.py b/src/virtual_stain_flow/evaluation/evaluation_utils.py index 8d1d285..bb2e4a2 100644 --- a/src/virtual_stain_flow/evaluation/evaluation_utils.py +++ b/src/virtual_stain_flow/evaluation/evaluation_utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Any import numpy as np import pandas as pd @@ -10,13 +10,25 @@ from virtual_stain_flow.datasets.base_wrapper_dataset import BaseWrapperDataset +def _to_numpy_image(value: Any) -> np.ndarray: + if isinstance(value, torch.Tensor): + return value.detach().cpu().numpy() + return np.asarray(value) + + +def _normalize_to_list(sample: Any) -> List[np.ndarray]: + if isinstance(sample, (list, tuple)): + return [_to_numpy_image(item) for item in sample] + return [_to_numpy_image(sample)] + + def extract_samples_from_dataset( dataset: Union[BaseImageDataset, CropImageDataset, BaseWrapperDataset], indices: List[int], ) -> Tuple[ - List[np.ndarray], - List[np.ndarray], - Optional[List[np.ndarray]], + List[Union[np.ndarray, List[np.ndarray]]], + List[Union[np.ndarray, List[np.ndarray]]], + Optional[List[Union[np.ndarray, List[np.ndarray]]]], Optional[List[Tuple[int, int]]], ]: """ @@ -26,13 +38,15 @@ def extract_samples_from_dataset( (x, y) coordinates of each crop for visualization with bounding boxes. :param dataset: A BaseImageDataset or CropImageDataset instance. - :param indices: List of dataset indices to extract. - :return: Tuple of (inputs, targets, raw_images, patch_coords). - - inputs: List of numpy arrays, each with shape (C, H, W) or (H, W). - - targets: List of numpy arrays, each with shape (C, H, W) or (H, W). - - raw_images: List of numpy arrays for CropImageDataset (original uncropped images), - or None for BaseImageDataset. - - patch_coords: List of (x, y) tuples for CropImageDataset, or None for BaseImageDataset. + :param indices: List of dataset indices to extract. + :return: Tuple of (inputs, targets, raw_images, patch_coords). + - inputs: List of numpy arrays, each with shape (C, H, W) or (H, W). + Multi-input samples can be provided as a list of arrays per sample. + - targets: List of numpy arrays, each with shape (C, H, W) or (H, W). + Multi-target samples can be provided as a list of arrays per sample. + - raw_images: List of numpy arrays for CropImageDataset (original uncropped images), + or None for BaseImageDataset. + - patch_coords: List of (x, y) tuples for CropImageDataset, or None for BaseImageDataset. """ is_wrapper_dataset = False if isinstance(dataset, BaseWrapperDataset): @@ -55,9 +69,9 @@ def extract_samples_from_dataset( f"max index requested: {max(indices)}" ) - inputs: List[np.ndarray] = [] - targets: List[np.ndarray] = [] - raw_images: Optional[List[np.ndarray]] = [] if is_crop_dataset else None + inputs: List[Union[np.ndarray, List[np.ndarray]]] = [] + targets: List[Union[np.ndarray, List[np.ndarray]]] = [] + raw_images: Optional[List[Union[np.ndarray, List[np.ndarray]]]] = [] if is_crop_dataset else None patch_coords: Optional[List[Tuple[int, int]]] = [] if is_crop_dataset else None for idx in indices: @@ -65,15 +79,11 @@ def extract_samples_from_dataset( input_tensor, target_tensor = dataset[idx] # Convert to numpy - handle both Tensor and ndarray inputs - if isinstance(input_tensor, torch.Tensor): - inputs.append(input_tensor.numpy()) - else: - inputs.append(np.asarray(input_tensor)) - - if isinstance(target_tensor, torch.Tensor): - targets.append(target_tensor.numpy()) - else: - targets.append(np.asarray(target_tensor)) + input_list = _normalize_to_list(input_tensor) + target_list = _normalize_to_list(target_tensor) + + inputs.append(input_list[0] if len(input_list) == 1 else input_list) + targets.append(target_list[0] if len(target_list) == 1 else target_list) if is_crop_dataset: # Access the original uncropped image and crop coordinates diff --git a/src/virtual_stain_flow/evaluation/predict_utils.py b/src/virtual_stain_flow/evaluation/predict_utils.py index b0eea34..4e5829a 100644 --- a/src/virtual_stain_flow/evaluation/predict_utils.py +++ b/src/virtual_stain_flow/evaluation/predict_utils.py @@ -1,18 +1,26 @@ -from typing import Optional, List, Tuple, Callable +from typing import Optional, List, Tuple, Callable, Union, Any import torch import numpy as np from torch.utils.data import DataLoader, Dataset, Subset from albumentations import ImageOnlyTransform, Compose +def _move_to_device(value: Any, device: Union[str, torch.device]) -> Any: + if isinstance(value, torch.Tensor): + return value.to(device) + if isinstance(value, (list, tuple)): + return type(value)(_move_to_device(item, device) for item in value) + return value + + def predict_image( dataset: Dataset, model: torch.nn.Module, batch_size: int = 1, - device: str = "cpu", + device: Union[str, torch.device] = "cpu", num_workers: int = 0, - indices: Optional[List[int]] = None -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + indices: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]]: """ Runs a model on a dataset, performing a forward pass on all (or a subset of) input images in evaluation mode and returning a stacked tensor of predictions. @@ -27,6 +35,7 @@ def predict_image( :param indices: Optional list of dataset indices to subset the dataset before inference. :return: Tuple of stacked target, prediction, and input tensors. + For multi-input datasets, the third element is a list of stacked input tensors. """ # Subset the dataset if indices are provided if indices is not None: @@ -38,25 +47,41 @@ def predict_image( model.to(device) model.eval() - predictions, targets, inputs = [], [], [] + predictions, targets = [], [] + inputs: Union[List[torch.Tensor], List[List[torch.Tensor]]] = [] with torch.no_grad(): for input, target in dataloader: # Unpacking (input_tensor, target_tensor) - input = input.to(device) # Move input data to the specified device + input = _move_to_device(input, device) # Forward pass - prediction = model(input) - + if isinstance(input, (list, tuple)): + prediction = model(*input) + else: + prediction = model(input) + # output both target and prediction tensors for metric targets.append(target.cpu()) predictions.append(prediction.cpu()) # Move to CPU for stacking - inputs.append(input.cpu()) + + if isinstance(input, (list, tuple)): + if not inputs: + inputs = [[] for _ in range(len(input))] + for idx, item in enumerate(input): + inputs[idx].append(item.cpu()) + else: + inputs.append(input.cpu()) + + if inputs and isinstance(inputs[0], list): + inputs_stacked = [torch.cat(batch_list, dim=0) for batch_list in inputs] # type: ignore[arg-type] + else: + inputs_stacked = torch.cat(inputs, dim=0) # type: ignore[arg-type] return ( - torch.cat(targets, dim=0), - torch.cat(predictions, dim=0), - torch.cat(inputs, dim=0) - ) + torch.cat(targets, dim=0), + torch.cat(predictions, dim=0), + inputs_stacked, + ) def process_tensor_image( img_tensor: torch.Tensor, diff --git a/src/virtual_stain_flow/evaluation/visualization.py b/src/virtual_stain_flow/evaluation/visualization.py index 271f54e..08da34e 100644 --- a/src/virtual_stain_flow/evaluation/visualization.py +++ b/src/virtual_stain_flow/evaluation/visualization.py @@ -5,7 +5,7 @@ Supports both BaseImageDataset and CropImageDataset with optional metrics display. """ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Any import numpy as np import pandas as pd @@ -20,14 +20,56 @@ from .predict_utils import predict_image +def _to_numpy_image(value: Any) -> np.ndarray: + if isinstance(value, torch.Tensor): + return value.detach().cpu().numpy() + return np.asarray(value) + + +def _normalize_to_list(sample: Any) -> List[np.ndarray]: + if isinstance(sample, (list, tuple)): + return [_to_numpy_image(item) for item in sample] + return [_to_numpy_image(sample)] + + +def _split_channels(img: np.ndarray, channel_indices: Optional[List[int]]) -> List[np.ndarray]: + if img.ndim == 2: + if channel_indices is not None and any(idx != 0 for idx in channel_indices): + raise ValueError("Channel indices out of range for 2D image.") + return [img] + + if img.ndim == 3: + total_channels = img.shape[0] + indices = channel_indices if channel_indices is not None else list(range(total_channels)) + for idx in indices: + if idx < 0 or idx >= total_channels: + raise ValueError( + f"Channel index {idx} out of range for image with {total_channels} channels." + ) + return [img[idx] for idx in indices] + + raise ValueError(f"Unsupported image shape for visualization: {img.shape}") + + +def _split_sample_channels(sample: Any, channel_indices: Optional[List[int]]) -> List[np.ndarray]: + split_images: List[np.ndarray] = [] + for item in _normalize_to_list(sample): + split_images.extend(_split_channels(item, channel_indices)) + return split_images + + +def _build_titles(prefix: str, count: int) -> List[str]: + return [f"{prefix} {i + 1}" for i in range(count)] + + def plot_predictions_grid( - inputs: List[np.ndarray], - targets: List[np.ndarray], - predictions: Optional[List[np.ndarray]] = None, + inputs: List[Union[np.ndarray, List[np.ndarray]]], + targets: List[Union[np.ndarray, List[np.ndarray]]], + predictions: Optional[List[Union[np.ndarray, List[np.ndarray]]]] = None, *, sample_indices: Optional[List[int]] = None, row_label_prefix: str = "", - raw_images: Optional[List[np.ndarray]] = None, + raw_images: Optional[List[Union[np.ndarray, List[np.ndarray]]]] = None, patch_coords: Optional[List[Tuple[int, int]]] = None, metrics_df: Optional[pd.DataFrame] = None, save_path: Optional[str] = None, @@ -36,6 +78,10 @@ def plot_predictions_grid( show_plot: bool = True, wspace: float = 0.05, hspace: float = 0.15, + raw_channel_indices: Optional[List[int]] = None, + input_channel_indices: Optional[List[int]] = None, + target_channel_indices: Optional[List[int]] = None, + prediction_channel_indices: Optional[List[int]] = None, ) -> plt.Figure: """ Plot a grid of images comparing inputs, targets, and optionally predictions. @@ -47,8 +93,11 @@ def plot_predictions_grid( - [Prediction] (only if predictions provided, with optional metrics in title) :param inputs: List of input images, each (C, H, W) or (H, W). + Multi-input samples can be provided as a list of arrays per sample. :param targets: List of target images, each (C, H, W) or (H, W). + Multi-target samples can be provided as a list of arrays per sample. :param predictions: Optional list of prediction images, each (C, H, W) or (H, W). + Multi-output samples can be provided as a list of arrays per sample. If None, only inputs and targets are displayed. :param sample_indices: Optional list of indices to display as row labels. If None, uses 0-based sequential indices. @@ -66,6 +115,10 @@ def plot_predictions_grid( :param show_plot: Whether to display the plot (default: True). :param wspace: Horizontal spacing between subplots (default: 0.05). :param hspace: Vertical spacing between subplots (default: 0.15). + :param raw_channel_indices: Optional list of channel indices to display for raw images. + :param input_channel_indices: Optional list of channel indices to display for inputs. + :param target_channel_indices: Optional list of channel indices to display for targets. + :param prediction_channel_indices: Optional list of channel indices to display for predictions. """ num_samples = len(inputs) if num_samples == 0: @@ -84,24 +137,55 @@ def plot_predictions_grid( ) has_raw_images = raw_images is not None and len(raw_images) > 0 + if has_raw_images and len(raw_images) != num_samples: + raise ValueError( + f"Length mismatch: inputs ({num_samples}), raw_images ({len(raw_images)})" + ) + + if sample_indices is not None and len(sample_indices) != num_samples: + raise ValueError( + f"Length mismatch: inputs ({num_samples}), sample_indices ({len(sample_indices)})" + ) + + if has_predictions and prediction_channel_indices is None: + prediction_channel_indices = target_channel_indices - # Determine number of columns based on what's provided - # Base: Input, Target (2 cols) or Input, Target, Prediction (3 cols) - # With raw: Raw, Input, Target (3 cols) or Raw, Input, Target, Prediction (4 cols) - num_cols = 2 + (1 if has_raw_images else 0) + (1 if has_predictions else 0) + # Determine columns based on channel splits in the first sample + raw_first = _split_sample_channels(raw_images[0], raw_channel_indices) if has_raw_images else [] + input_first = _split_sample_channels(inputs[0], input_channel_indices) + target_first = _split_sample_channels(targets[0], target_channel_indices) + pred_first = ( + _split_sample_channels(predictions[0], prediction_channel_indices) + if has_predictions + else [] + ) + + if has_predictions and len(pred_first) != len(target_first): + raise ValueError( + "Target and prediction channel counts must match for paired display." + ) + + raw_titles = _build_titles("Raw Input", len(raw_first)) + input_titles = _build_titles("Input", len(input_first)) + target_titles = _build_titles("Target", len(target_first)) + pred_titles = _build_titles("Prediction", len(pred_first)) + + if has_predictions: + interleaved_titles = [ + title + for i in range(len(target_titles)) + for title in (target_titles[i], pred_titles[i]) + ] + else: + interleaved_titles = target_titles + + column_titles = raw_titles + input_titles + interleaved_titles + num_cols = len(column_titles) # Default sample indices if not provided if sample_indices is None: sample_indices = list(range(num_samples)) - # Column titles - build dynamically - column_titles = [] - if has_raw_images: - column_titles.append("Raw Input") - column_titles.extend(["Input", "Target"]) - if has_predictions: - column_titles.append("Prediction") - # Create figure fig_width = panel_width * num_cols fig_height = panel_width * num_samples @@ -112,13 +196,35 @@ def plot_predictions_grid( axes = axes.reshape(1, -1) for row_idx in range(num_samples): - # Build image set for this row dynamically - img_set = [] - if has_raw_images: - img_set.append(raw_images[row_idx]) - img_set.extend([inputs[row_idx], targets[row_idx]]) + raw_row = _split_sample_channels(raw_images[row_idx], raw_channel_indices) if has_raw_images else [] + input_row = _split_sample_channels(inputs[row_idx], input_channel_indices) + target_row = _split_sample_channels(targets[row_idx], target_channel_indices) + pred_row = ( + _split_sample_channels(predictions[row_idx], prediction_channel_indices) + if has_predictions + else [] + ) + + if has_predictions and len(pred_row) != len(target_row): + raise ValueError( + f"Row {row_idx} has mismatched target/prediction channel counts." + ) + if has_predictions: - img_set.append(predictions[row_idx]) + target_pred_row = [ + img + for i in range(len(target_row)) + for img in (target_row[i], pred_row[i]) + ] + else: + target_pred_row = target_row + + img_set = raw_row + input_row + target_pred_row + + if len(img_set) != num_cols: + raise ValueError( + f"Row {row_idx} has {len(img_set)} columns, expected {num_cols}." + ) for col_idx, img in enumerate(img_set): ax = axes[row_idx, col_idx] @@ -134,10 +240,10 @@ def plot_predictions_grid( ax.axis("off") # Draw bounding box on raw image - if has_raw_images and col_idx == 0 and patch_coords is not None: + if has_raw_images and col_idx < len(raw_titles) and patch_coords is not None: patch_x, patch_y = patch_coords[row_idx] # Infer patch size from target shape - target_shape = np.squeeze(targets[row_idx]).shape + target_shape = np.squeeze(target_row[0]).shape patch_h, patch_w = target_shape[-2], target_shape[-1] rect = Rectangle( (patch_x, patch_y), @@ -149,8 +255,8 @@ def plot_predictions_grid( ) ax.add_patch(rect) - # Row label on the Input column (first column if no raw images, second if raw images) - input_col_idx = 1 if has_raw_images else 0 + # Row label on the first input column + input_col_idx = len(raw_titles) if col_idx == input_col_idx: row_label = f"{row_label_prefix}{sample_indices[row_idx]}" # Determine text color based on top-left corner brightness @@ -175,7 +281,7 @@ def plot_predictions_grid( horizontalalignment="left", ) - # Metrics on prediction column (last column, only if predictions provided) + # Metrics on last prediction column (only if predictions provided) if has_predictions and metrics_df is not None and row_idx < len(metrics_df): metric_values = metrics_df.iloc[row_idx] metric_text = "\n".join( @@ -218,7 +324,8 @@ def plot_dataset_grid( :param indices: List of dataset indices to display. :param save_path: Optional path to save the figure. :param kwargs: Additional arguments passed to `plot_predictions_grid`. - Supported: row_label_prefix, cmap, panel_width, show_plot, wspace, hspace. + Supported: row_label_prefix, cmap, panel_width, show_plot, wspace, hspace, + raw_channel_indices, input_channel_indices, target_channel_indices, prediction_channel_indices. """ # Extract samples from dataset ( @@ -266,7 +373,8 @@ def plot_predictions_grid_from_model( :param device: Device for inference ("cpu" or "cuda"). :param save_path: Optional path to save the figure. :param kwargs: Additional arguments passed to `plot_predictions_grid`. - Supported: row_label_prefix, cmap, panel_width, show_plot, wspace, hspace. + Supported: row_label_prefix, cmap, panel_width, show_plot, wspace, hspace, + raw_channel_indices, input_channel_indices, target_channel_indices, prediction_channel_indices. """ # Step 1: Run inference targets_tensor, predictions_tensor, inputs_tensor = predict_image( @@ -282,7 +390,15 @@ def plot_predictions_grid_from_model( _, _, raw_images, patch_coords = extract_samples_from_dataset(dataset, indices) # use the collected input and target stack at prediction time instead of # re-extract - inputs, targets = inputs_tensor.numpy(), targets_tensor.numpy() + if isinstance(inputs_tensor, list): + inputs = [ + [inputs_tensor[i][row_idx].numpy() for i in range(len(inputs_tensor))] + for row_idx in range(len(indices)) + ] + else: + inputs = inputs_tensor.numpy() + + targets = targets_tensor.numpy() # Convert predictions tensor to list of numpy arrays predictions = [predictions_tensor[i].numpy() for i in range(len(indices))]