diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2f2c740616..18cbd2f67d 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -4,19 +4,24 @@ from spikeinterface.core import BaseRecording from spikeinterface.core.base import base_period_dtype +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_random_data_chunks +from spikeinterface.core.node_pipeline import PeakDetector, run_node_pipeline, PipelineNode from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels, get_random_data_chunks -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording artifact_dtype = base_period_dtype -# this will be extend with channel boundaries if needed -# extended_artifact_dtype = artifact_dtype + [ -# # TODO -# ] +signed_artifact_dtype = np.dtype(artifact_dtype + [("sign", "U8")]) + + +def _indent_docstring(docstring: str, indent: int = 4) -> str: + indent_str = " " * indent + indented_lines = [(indent_str + line) if line.strip() else line for line in docstring.splitlines()] + return "\n".join(indented_lines) def _collapse_events(events: np.ndarray) -> np.ndarray: @@ -77,6 +82,7 @@ def __init__( saturation_threshold_uV: float, diff_threshold_uV: float | None, proportion: float, + signed: bool, ) -> None: """ Parameters @@ -91,6 +97,11 @@ def __init__( proportion : float Fraction of channels that must exceed the threshold for a sample to be labelled as saturated (0 < proportion < 1). + signed : bool + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``. """ PipelineNode.__init__(self, recording, return_output=True) @@ -102,6 +113,7 @@ def __init__( # slightly lower than the documented saturation point of the probe self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion + self.signed = signed self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() @@ -122,6 +134,27 @@ def get_dtype(self) -> np.dtype: """Return the NumPy dtype of the output array produced by :meth:`compute`.""" return self._dtype + def detect_in_chunk(self, traces, saturation_threshold, diff_threshold, proportion) -> np.ndarray: + saturation = np.mean(traces > saturation_threshold, axis=1) + detected_by_value = saturation > proportion + + if diff_threshold is not None: + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= diff_threshold, axis=1) + + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + detected_by_diff = n_diff_saturated > proportion + saturation = np.logical_or(detected_by_value, detected_by_diff) + else: + saturation = detected_by_value + + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) + return intervals + def compute( self, traces: np.ndarray, @@ -165,33 +198,62 @@ def compute( # cast to float32 to prevent overflow when applying thresholds in unscaled ADC units traces = traces.astype("float32") - saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - detected_by_value = saturation > self.proportion - - if self.diff_threshold_unscaled is not None: - # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= self.diff_threshold_unscaled, axis=1) - - # Note this means the velocity is not checked for the last sample in the - # check because we are taking the forward derivative - n_diff_saturated = np.r_[n_diff_saturated, 0] - - # if either of those reaches more than the proportion of channels labels the sample as saturated - detected_by_diff = n_diff_saturated > self.proportion - saturation = np.logical_or(detected_by_value, detected_by_diff) + if not self.signed: + traces = np.abs(traces) + intervals = self.detect_in_chunk( + traces, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=artifact_dtype) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index else: - saturation = detected_by_value + all_events = [] + for sign in (1, -1): + traces_signed = sign * traces + intervals = self.detect_in_chunk( + traces_signed, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=signed_artifact_dtype) + events["sign"] = "positive" if sign == 1 else "negative" + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + all_events.append(events) + all_events = np.concatenate(all_events) + # sort by start sample index + order = np.argsort(all_events["start_sample_index"]) + events = all_events[order] - intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) - n_events = len(intervals) // 2 # Number of saturation periods - events = np.zeros(n_events, dtype=artifact_dtype) + return (events,) - for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): - events[i]["start_sample_index"] = start + start_frame - events[i]["end_sample_index"] = stop + start_frame - events[i]["segment_index"] = segment_index - return (events,) +_detect_saturation_periods_params = """saturation_threshold_uV : float | None, default: None + Voltage saturation threshold in μV. The appropriate value depends on + the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL + recommend **1200 μV**. NP2 probes are harder to saturate than NP1. + If ``None``, the value is read from the ``"saturation_threshold_uV"`` + annotation of ``recording``. +diff_threshold_uV : float | None, default: None + First-derivative threshold in μV/sample. Periods where the + sample-to-sample voltage change exceeds this value in the required + fraction of channels are flagged as saturation. Pass ``None`` to + disable derivative-based detection and rely solely on + ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. +proportion : float, default: 0.2 + Fraction of channels (0 < proportion < 1) that must exceed the + threshold for a sample to be considered saturated. +signed : bool, default: False + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``.""" def detect_saturation_periods( @@ -199,6 +261,7 @@ def detect_saturation_periods( saturation_threshold_uV: float | None = None, diff_threshold_uV: float | None = None, proportion: float = 0.2, + signed: bool = False, job_kwargs: dict | None = None, ) -> np.ndarray: """ @@ -224,21 +287,7 @@ def detect_saturation_periods( ---------- recording : BaseRecording The recording on which to detect saturation events. - saturation_threshold_uV : float | None, default: None - Voltage saturation threshold in μV. The appropriate value depends on - the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL - recommend **1200 μV**. NP2 probes are harder to saturate than NP1. - If ``None``, the value is read from the ``"saturation_threshold_uV"`` - annotation of ``recording``. - diff_threshold_uV : float | None, default: None - First-derivative threshold in μV/sample. Periods where the - sample-to-sample voltage change exceeds this value in the required - fraction of channels are flagged as saturation. Pass ``None`` to - disable derivative-based detection and rely solely on - ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. - proportion : float, default: 0.2 - Fraction of channels (0 < proportion < 1) that must exceed the - threshold for a sample to be considered saturated. + {} job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. ``n_jobs``, ``chunk_duration``). @@ -270,6 +319,7 @@ def detect_saturation_periods( saturation_threshold_uV=saturation_threshold_uV, diff_threshold_uV=diff_threshold_uV, proportion=proportion, + signed=signed, ) saturation_periods = run_node_pipeline( @@ -278,6 +328,9 @@ def detect_saturation_periods( return _collapse_events(saturation_periods) +detect_saturation_periods.__doc__ = detect_saturation_periods.__doc__.format(_detect_saturation_periods_params) + + ## detect_artifact_periods_by_envelope zone class _DetectThresholdCrossing(PeakDetector): """ @@ -382,6 +435,23 @@ def compute( return (threshold_crossings,) +_detect_artifacts_by_envelope_params = """detect_threshold : float, default: 5 + Detection threshold as a multiple of the estimated per-channel noise + level of the envelope. +freq_max : float, default: 20.0 + Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the + rectified signal when building the envelope. +seed : int | None, default: None + Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. + If ``None``, ``get_noise_levels`` uses ``seed=0``. +job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). +random_slices_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the ``random_slices_kwargs`` + argument of :func:`~spikeinterface.core.get_noise_levels`.""" + + def detect_artifact_periods_by_envelope( recording: BaseRecording, detect_threshold: float = 5, @@ -410,21 +480,7 @@ def detect_artifact_periods_by_envelope( ---------- recording : BaseRecording The recording extractor from which to detect artefact periods. - detect_threshold : float, default: 5 - Detection threshold as a multiple of the estimated per-channel noise - level of the envelope. - freq_max : float, default: 20.0 - Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the - rectified signal when building the envelope. - seed : int | None, default: None - Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. - If ``None``, ``get_noise_levels`` uses ``seed=0``. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). - random_slices_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the ``random_slices_kwargs`` - argument of :func:`~spikeinterface.core.get_noise_levels`. + {} return_envelope : bool, default: False If ``True``, also return the intermediate envelope recording so that it can be inspected or plotted. @@ -475,7 +531,6 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] artifacts = _collapse_events(artifacts) if return_envelope: @@ -484,6 +539,11 @@ def detect_artifact_periods_by_envelope( return artifacts +detect_artifact_periods_by_envelope.__doc__ = detect_artifact_periods_by_envelope.__doc__.format( + _detect_artifacts_by_envelope_params +) + + def _transform_internal_dtype_to_artifact_dtype( artifacts: np.ndarray, recording: BaseRecording, @@ -561,35 +621,41 @@ def detect_artifact_periods( job_kwargs: dict | None = None, ) -> np.ndarray: """ - Detect artifact periods using one of several available methods. + Detect artifact periods using one of several available methods. - Available methods: + Available methods: - * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified - channel envelope. - * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. - See the documentation of each sub-function for a full description of their - parameters, which can be forwarded via ``method_kwargs``. + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. - Parameters - ---------- - recording : BaseRecording - The recording on which to detect artifact periods. - method : {"envelope", "saturation"}, default: "envelope" - Detection method to use. - method_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the selected detection - function. Pass ``None`` to use that function's defaults. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). + Method-specific parameters include: - Returns - ------- - np.ndarray - Array with dtype ``artifact_dtype`` describing each detected artifact - period. + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. """ assert ( method in _method_to_function @@ -600,3 +666,118 @@ def detect_artifact_periods( artifact_periods = _method_to_function[method](recording, job_kwargs=job_kwargs, **method_kwargs) return artifact_periods + + +detect_artifact_periods.__doc__ = detect_artifact_periods.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) + + +class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): + """ + Detect and remove artifact periods using one of several available methods. + + Available methods: + + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. + + Method-specific parameters include: + + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + noise_levels_kwargs : dict | None, default: None + Keyword arguments for `spikeinterface.core.get_noise_levels()` function. + + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. + mode : "zeros" | "noise" | "apodization", default: "zeros" + Determines what periods are replaced by. Can be one of the following: + + - "zeros": Artifacts are replaced by zeros. + + - "noise": The periods are filled with a gaussion noise that has the + same variance that the one in the recordings, on a per channel + basis + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_factor`) + apodization_factor : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + artifact_periods : np.ndarray | None, default: None + Optionally, pre-computed artifact periods can be passed directly to the constructor to skip the + detection step. If ``None``, artifact periods are detected on the fly using the specified method + """ + + def __init__( + self, + recording: BaseRecording, + recording_to_detect: BaseRecording | None = None, + method: Literal["envelope", "saturation"] = "envelope", + method_kwargs: dict | None = None, + job_kwargs: dict | None = None, + mode: Literal["zeros", "noise", "apodization"] = "zeros", + noise_levels_kwargs: dict | None = None, + apodization: int = 7, + seed: int | None = None, + artifact_periods=None, + ) -> None: + if artifact_periods is not None: + artifact_periods = artifact_periods + else: + if recording_to_detect is None: + recording_to_detect = recording + artifact_periods = detect_artifact_periods( + recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs + ) + super().__init__( + recording, + periods=artifact_periods, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + apodization=apodization, + ) + + self._kwargs = dict( + recording=recording, + recording_to_detect=recording_to_detect, + method=method, + method_kwargs=method_kwargs, + job_kwargs=job_kwargs, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + artifact_periods=artifact_periods, + apodization=apodization, + ) + + +# function for API +detect_and_remove_artifacts = define_function_handling_dict_from_class( + source_class=DetectAndRemoveArtifactsRecording, name="detect_and_remove_artifacts" +) + +detect_and_remove_artifacts.__doc__ = detect_and_remove_artifacts.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 47e3c0906b..6ada71de44 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -48,6 +48,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed +from .detect_artifacts import DetectAndRemoveArtifactsRecording, detect_and_remove_artifacts # from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts @@ -72,6 +73,8 @@ # bad channel detection/interpolation DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels, DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels, + # artifact/saturation handling + DetectAndRemoveArtifactsRecording: detect_and_remove_artifacts, # misc RectifyRecording: rectify, ClipRecording: clip, diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 393c712919..5fadc1969c 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -1,4 +1,5 @@ import numpy as np +import scipy.signal from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -21,14 +22,11 @@ class SilencedPeriodsRecording(BasePreprocessor): ---------- recording : RecordingExtractor The recording extractor to silance periods - list_periods : list of lists/arrays - One list per segment of tuples (start_frame, end_frame) to silence - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise, default: "zeros" + periods : np.array + A numpy array with dtype `base_period_dtype` and fields + "segment_index", "start_sample_index", "end_sample_index". + Each row corresponds to a period to silence. + mode : "zeros" | "noise" | "apodization", default: "zeros" Determines what periods are replaced by. Can be one of the following: - "zeros": Artifacts are replaced by zeros. @@ -36,6 +34,14 @@ class SilencedPeriodsRecording(BasePreprocessor): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_factor`) + apodization_factor : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + noise_levels : array + Noise levels if already computed + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function Returns @@ -48,14 +54,15 @@ def __init__( self, recording, periods=None, - # this is keep for backward compatibility + # this is kept for backward compatibility list_periods=None, mode="zeros", + apodization_factor=7, noise_levels=None, seed=None, **noise_levels_kwargs, ): - available_modes = ("zeros", "noise") + available_modes = ("zeros", "noise", "apodization") num_seg = recording.get_num_segments() # handle backward compatibility with previous version @@ -108,11 +115,23 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index + parent_segment, + periods_in_seg, + mode, + noise_generator, + seg_index, + apodization_factor=apodization_factor, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) + self._kwargs = dict( + recording=recording, + periods=periods, + mode=mode, + seed=seed, + noise_levels=noise_levels, + apodization_factor=apodization_factor, + ) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -154,12 +173,13 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization_factor=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator + self.apodization_factor = apodization_factor def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) @@ -185,7 +205,13 @@ def get_traces(self, start_frame, end_frame, channel_indices): :, channel_indices ] traces[onset:offset, :] = noise[onset:offset] - + elif self.mode == "apodization": + # apply a cosine taper to the saturation to create a mute function + mute = np.zeros(traces.shape[0], dtype=np.float32) + mute[onset:offset] = 1 + win = scipy.signal.windows.cosine(self.apodization_factor) + mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) + traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) return traces diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 1d99206bd2..76347f56b8 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -5,6 +5,7 @@ detect_artifact_periods, detect_saturation_periods, detect_artifact_periods_by_envelope, + detect_and_remove_artifacts, ) @@ -238,6 +239,119 @@ def test_detect_saturation_periods(debug_plots): assert np.array_equal(periods, periods_entry_with_annotation) +def test_detect_saturation_signed(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, data, axis=0) + + # Inject positive saturation in first third, negative in second third + pos_start, pos_stop = 15000, 15500 + neg_start, neg_stop = 45000, 45500 + data[pos_start:pos_stop, :] = sat_value + data[neg_start:neg_stop, :] = -sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + + recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + periods = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * 0.98, + signed=True, + job_kwargs=job_kwargs, + ) + + # Output dtype must include the "sign" field + assert "sign" in periods.dtype.names + + pos_periods = periods[periods["sign"] == "positive"] + neg_periods = periods[periods["sign"] == "negative"] + assert len(pos_periods) > 0, "No positive saturation periods detected" + assert len(neg_periods) > 0, "No negative saturation periods detected" + + # Positive period should be near the injected positive saturation + tolerance = 1 + assert np.any(np.abs(pos_periods["start_sample_index"] - pos_start) <= tolerance) + assert np.any(np.abs(pos_periods["end_sample_index"] - pos_stop) <= tolerance) + + # Negative period should be near the injected negative saturation + assert np.any(np.abs(neg_periods["start_sample_index"] - neg_start) <= tolerance) + assert np.any(np.abs(neg_periods["end_sample_index"] - neg_stop) <= tolerance) + + # Positive periods must not contain any sample indices from the negative injection + for p in pos_periods: + assert not (p["start_sample_index"] < neg_stop and p["end_sample_index"] > neg_start) + + # Negative periods must not contain any sample indices from the positive injection + for p in neg_periods: + assert not (p["start_sample_index"] < pos_stop and p["end_sample_index"] > pos_start) + + +def test_detect_and_remove_artifacts(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, data, axis=0) + + sat_start, sat_stop = 15000, 15500 + data[sat_start:sat_stop, :] = sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + + recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + # Basic usage: detect and zero out saturation in one step + cleaned = detect_and_remove_artifacts( + recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + traces = cleaned.get_traces(segment_index=0) + assert traces[sat_start + 100, 0] == 0, "Saturated samples should be zeroed" + assert traces[0, 0] != 0, "Non-saturated samples should not be zeroed" + + # recording_to_detect: detect on raw recording, silence a separate (processed) recording + # We use the same recording here just to exercise the code path + cleaned_with_detect = detect_and_remove_artifacts( + recording, + recording_to_detect=recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + assert np.array_equal(cleaned.get_traces(), cleaned_with_detect.get_traces()) + + if __name__ == "__main__": # test_detect_artifact_by_envelope(True) test_detect_saturation_periods(False) diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index 44bd205f1b..4eee8646dc 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -39,7 +39,6 @@ def test_silence(create_cache_folder): data1 = rec.get_traces(0, 400, 600) data2 = rec.get_traces(0, 500, 700) assert np.all(data1[100:] == data2[:100]) - traces_mix = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100) traces_original = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100) assert np.all(traces_original[100:-100] == traces_mix[100:-100]) @@ -48,6 +47,26 @@ def test_silence(create_cache_folder): assert not np.all(traces_mix[:200] == 0) assert not np.all(traces_mix[:-200] == 0) + # test that the apodization creates a taper + apodization_factor = 10 + rec2 = silence_periods(rec, periods=periods, mode="apodization", apodization_factor=apodization_factor) + rec2 = rec2.save(format="memory", verbose=False, overwrite=True) + traces_in0 = rec2.get_traces(segment_index=0, start_frame=0, end_frame=1000) + traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000, end_frame=6000) + # all apodized traces + assert np.all(traces_in0 == 0) + assert np.all(traces_in1 == 0) + + # at margins, traces should not be all zero, but should be apodized + apodized_traces_in0 = rec2.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) + apodized_traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + traces_raw_in0 = rec.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) + traces_raw_in1 = rec.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + # the apodized traces should be less than the raw traces in absolute value, + # since they are multiplied by a cosine taper between 0 and 1 + assert np.all(np.abs(apodized_traces_in0) <= np.abs(traces_raw_in0)) + assert np.all(np.abs(apodized_traces_in1) <= np.abs(traces_raw_in1)) + if __name__ == "__main__": cache_folder = Path(__file__).resolve().parents[4] / "cache_folder"