From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 1/7] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 61c317aba92608d9f096a3a374bc3d43e27faaba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Mar 2026 10:09:46 -0800 Subject: [PATCH 2/7] Fix OpenEphys tests --- .../extractors/neoextractors/openephys.py | 20 ++++++++++++------- .../extractors/tests/test_neoextractors.py | 3 +++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 1c39a1b97c..1d16df534b 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -351,13 +351,19 @@ def __init__( # Ensure device channel index corresponds to channel_ids probe_channel_names = probe.contact_annotations.get("channel_name", None) if probe_channel_names is not None and not np.array_equal(probe_channel_names, self.channel_ids): - device_channel_indices = [] - probe_channel_names = list(probe_channel_names) - 
device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) - for i, ch in enumerate(self.channel_ids): - index_in_probe = probe_channel_names.index(ch) - device_channel_indices[index_in_probe] = i - probe.set_device_channel_indices(device_channel_indices) + if set(probe_channel_names) == set(self.channel_ids): + device_channel_indices = [] + probe_channel_names = list(probe_channel_names) + device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) + for i, ch in enumerate(self.channel_ids): + index_in_probe = probe_channel_names.index(ch) + device_channel_indices[index_in_probe] = i + probe.set_device_channel_indices(device_channel_indices) + else: + warnings.warn( + "Channel names in the probe do not match the channel ids from Neo. " + "Cannot set device channel indices, but this might lead to incorrect probe geometries" + ) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index f80f62ebf0..f40b4d05ab 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -121,6 +121,9 @@ class OpenEphysBinaryRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "0"}), ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "1"}), ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "0", "block_index": 0}), + # TODO: block_indices 1/2 of v0.6.x_neuropixels_multiexp_multistream have a mismatch in the channel names between + # the settings files (starting with CH0) and structure.oebin (starting at CH1). 
+ # Currently, the extractor will skip remapping to match order in oebin and settings file, raising a warning ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "1", "block_index": 1}), ( "openephysbinary/v0.6.x_neuropixels_multiexp_multistream", From 49c51dadf9f802e772e83f7ee23a5f33be66a2ac Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 20 Apr 2026 17:34:19 +0200 Subject: [PATCH 3/7] feat: implement DetectAndRemoveArtifacts and signed saturation --- .../preprocessing/detect_artifacts.py | 325 +++++++++++++----- .../preprocessing/preprocessing_classes.py | 3 + .../tests/test_detect_artifacts.py | 114 ++++++ 3 files changed, 357 insertions(+), 85 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2f2c740616..4a4dbf17f3 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -4,19 +4,24 @@ from spikeinterface.core import BaseRecording from spikeinterface.core.base import base_period_dtype +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_random_data_chunks +from spikeinterface.core.node_pipeline import PeakDetector, run_node_pipeline, PipelineNode from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels, get_random_data_chunks -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording artifact_dtype = 
base_period_dtype -# this will be extend with channel boundaries if needed -# extended_artifact_dtype = artifact_dtype + [ -# # TODO -# ] +signed_artifact_dtype = np.dtype(artifact_dtype + [("sign", "U8")]) + + +def _indent_docstring(docstring: str, indent: int = 4) -> str: + indent_str = " " * indent + indented_lines = [(indent_str + line) if line.strip() else line for line in docstring.splitlines()] + return "\n".join(indented_lines) def _collapse_events(events: np.ndarray) -> np.ndarray: @@ -77,6 +82,7 @@ def __init__( saturation_threshold_uV: float, diff_threshold_uV: float | None, proportion: float, + signed: bool, ) -> None: """ Parameters @@ -91,6 +97,11 @@ def __init__( proportion : float Fraction of channels that must exceed the threshold for a sample to be labelled as saturated (0 < proportion < 1). + signed : bool + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``. 
""" PipelineNode.__init__(self, recording, return_output=True) @@ -102,6 +113,7 @@ def __init__( # slightly lower than the documented saturation point of the probe self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion + self.signed = signed self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() @@ -122,6 +134,27 @@ def get_dtype(self) -> np.dtype: """Return the NumPy dtype of the output array produced by :meth:`compute`.""" return self._dtype + def detect_in_chunk(self, traces, saturation_threshold, diff_threshold, proportion) -> np.ndarray: + saturation = np.mean(traces > saturation_threshold, axis=1) + detected_by_value = saturation > proportion + + if diff_threshold is not None: + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= diff_threshold, axis=1) + + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + detected_by_diff = n_diff_saturated > proportion + saturation = np.logical_or(detected_by_value, detected_by_diff) + else: + saturation = detected_by_value + + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) + return intervals + def compute( self, traces: np.ndarray, @@ -165,33 +198,63 @@ def compute( # cast to float32 to prevent overflow when applying thresholds in unscaled ADC units traces = traces.astype("float32") - saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - detected_by_value = saturation > self.proportion - - if self.diff_threshold_unscaled is not None: - # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= 
self.diff_threshold_unscaled, axis=1) - - # Note this means the velocity is not checked for the last sample in the - # check because we are taking the forward derivative - n_diff_saturated = np.r_[n_diff_saturated, 0] - - # if either of those reaches more than the proportion of channels labels the sample as saturated - detected_by_diff = n_diff_saturated > self.proportion - saturation = np.logical_or(detected_by_value, detected_by_diff) + if not self.signed: + traces = np.abs(traces) + intervals = self.detect_in_chunk( + traces, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=artifact_dtype) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index else: - saturation = detected_by_value + all_events = [] + for sign in (1, -1): + traces_signed = sign * traces + intervals = self.detect_in_chunk( + traces_signed, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=signed_artifact_dtype) + events["sign"] = "positive" if sign == 1 else "negative" + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + all_events.append(events) + all_events = np.concatenate(all_events) + # sort by start sample index + order = np.argsort(all_events["start_sample_index"]) + all_events = all_events[order] + return (all_events,) - intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) - n_events = len(intervals) // 2 # Number of saturation periods - events = 
np.zeros(n_events, dtype=artifact_dtype) + return (events,) - for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): - events[i]["start_sample_index"] = start + start_frame - events[i]["end_sample_index"] = stop + start_frame - events[i]["segment_index"] = segment_index - return (events,) +_detect_saturation_periods_params = """saturation_threshold_uV : float | None, default: None + Voltage saturation threshold in μV. The appropriate value depends on + the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL + recommend **1200 μV**. NP2 probes are harder to saturate than NP1. + If ``None``, the value is read from the ``"saturation_threshold_uV"`` + annotation of ``recording``. +diff_threshold_uV : float | None, default: None + First-derivative threshold in μV/sample. Periods where the + sample-to-sample voltage change exceeds this value in the required + fraction of channels are flagged as saturation. Pass ``None`` to + disable derivative-based detection and rely solely on + ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. +proportion : float, default: 0.2 + Fraction of channels (0 < proportion < 1) that must exceed the + threshold for a sample to be considered saturated. +signed : bool, default: False + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. 
If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``.""" def detect_saturation_periods( @@ -199,6 +262,7 @@ def detect_saturation_periods( saturation_threshold_uV: float | None = None, diff_threshold_uV: float | None = None, proportion: float = 0.2, + signed: bool = False, job_kwargs: dict | None = None, ) -> np.ndarray: """ @@ -224,21 +288,7 @@ def detect_saturation_periods( ---------- recording : BaseRecording The recording on which to detect saturation events. - saturation_threshold_uV : float | None, default: None - Voltage saturation threshold in μV. The appropriate value depends on - the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL - recommend **1200 μV**. NP2 probes are harder to saturate than NP1. - If ``None``, the value is read from the ``"saturation_threshold_uV"`` - annotation of ``recording``. - diff_threshold_uV : float | None, default: None - First-derivative threshold in μV/sample. Periods where the - sample-to-sample voltage change exceeds this value in the required - fraction of channels are flagged as saturation. Pass ``None`` to - disable derivative-based detection and rely solely on - ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. - proportion : float, default: 0.2 - Fraction of channels (0 < proportion < 1) that must exceed the - threshold for a sample to be considered saturated. + {} job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. ``n_jobs``, ``chunk_duration``). 
@@ -270,6 +320,7 @@ def detect_saturation_periods( saturation_threshold_uV=saturation_threshold_uV, diff_threshold_uV=diff_threshold_uV, proportion=proportion, + signed=signed, ) saturation_periods = run_node_pipeline( @@ -278,6 +329,9 @@ def detect_saturation_periods( return _collapse_events(saturation_periods) +detect_saturation_periods.__doc__ = detect_saturation_periods.__doc__.format(_detect_saturation_periods_params) + + ## detect_artifact_periods_by_envelope zone class _DetectThresholdCrossing(PeakDetector): """ @@ -382,6 +436,23 @@ def compute( return (threshold_crossings,) +_detect_artifacts_by_envelope_params = """detect_threshold : float, default: 5 + Detection threshold as a multiple of the estimated per-channel noise + level of the envelope. +freq_max : float, default: 20.0 + Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the + rectified signal when building the envelope. +seed : int | None, default: None + Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. + If ``None``, ``get_noise_levels`` uses ``seed=0``. +job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). +random_slices_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the ``random_slices_kwargs`` + argument of :func:`~spikeinterface.core.get_noise_levels`.""" + + def detect_artifact_periods_by_envelope( recording: BaseRecording, detect_threshold: float = 5, @@ -410,21 +481,7 @@ def detect_artifact_periods_by_envelope( ---------- recording : BaseRecording The recording extractor from which to detect artefact periods. - detect_threshold : float, default: 5 - Detection threshold as a multiple of the estimated per-channel noise - level of the envelope. - freq_max : float, default: 20.0 - Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the - rectified signal when building the envelope. 
- seed : int | None, default: None - Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. - If ``None``, ``get_noise_levels`` uses ``seed=0``. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). - random_slices_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the ``random_slices_kwargs`` - argument of :func:`~spikeinterface.core.get_noise_levels`. + {} return_envelope : bool, default: False If ``True``, also return the intermediate envelope recording so that it can be inspected or plotted. @@ -475,7 +532,6 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] artifacts = _collapse_events(artifacts) if return_envelope: @@ -484,6 +540,11 @@ def detect_artifact_periods_by_envelope( return artifacts +detect_artifact_periods_by_envelope.__doc__ = detect_artifact_periods_by_envelope.__doc__.format( + _detect_artifacts_by_envelope_params +) + + def _transform_internal_dtype_to_artifact_dtype( artifacts: np.ndarray, recording: BaseRecording, @@ -561,35 +622,41 @@ def detect_artifact_periods( job_kwargs: dict | None = None, ) -> np.ndarray: """ - Detect artifact periods using one of several available methods. + Detect artifact periods using one of several available methods. - Available methods: + Available methods: - * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified - channel envelope. - * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. 
+ * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. - See the documentation of each sub-function for a full description of their - parameters, which can be forwarded via ``method_kwargs``. + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. - Parameters - ---------- - recording : BaseRecording - The recording on which to detect artifact periods. - method : {"envelope", "saturation"}, default: "envelope" - Detection method to use. - method_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the selected detection - function. Pass ``None`` to use that function's defaults. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). + Method-specific parameters include: - Returns - ------- - np.ndarray - Array with dtype ``artifact_dtype`` describing each detected artifact - period. + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. 
""" assert ( method in _method_to_function @@ -600,3 +667,91 @@ def detect_artifact_periods( artifact_periods = _method_to_function[method](recording, job_kwargs=job_kwargs, **method_kwargs) return artifact_periods + + +detect_artifact_periods.__doc__ = detect_artifact_periods.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) + + +class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): + """ + Detect and remove artifact periods using one of several available methods. + + Available methods: + + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. + + Method-specific parameters include: + + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). 
+ """ + + def __init__( + self, + recording: BaseRecording, + recording_to_detect: BaseRecording | None = None, + method: Literal["envelope", "saturation"] = "envelope", + method_kwargs: dict | None = None, + job_kwargs: dict | None = None, + mode: Literal["zeros", "noise"] = "zeros", + noise_levels_kwargs: dict | None = None, + seed: int | None = None, + artifact_periods=None, + ) -> None: + if artifact_periods is not None: + artifact_periods = artifact_periods + else: + if recording_to_detect is None: + recording_to_detect = recording + artifact_periods = detect_artifact_periods( + recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs + ) + super().__init__( + recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed + ) + + self._kwargs = dict( + recording=recording, + recording_to_detect=recording_to_detect, + method=method, + method_kwargs=method_kwargs, + job_kwargs=job_kwargs, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + artifact_periods=artifact_periods, + ) + + +# function for API +detect_and_remove_artifacts = define_function_handling_dict_from_class( + source_class=DetectAndRemoveArtifactsRecording, name="detect_and_remove_artifacts" +) + +detect_and_remove_artifacts.__doc__ = detect_and_remove_artifacts.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 47e3c0906b..6ada71de44 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -48,6 +48,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import 
UnsignedToSignedRecording, unsigned_to_signed +from .detect_artifacts import DetectAndRemoveArtifactsRecording, detect_and_remove_artifacts # from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts @@ -72,6 +73,8 @@ # bad channel detection/interpolation DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels, DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels, + # artifact/saturation handling + DetectAndRemoveArtifactsRecording: detect_and_remove_artifacts, # misc RectifyRecording: rectify, ClipRecording: clip, diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 1d99206bd2..76347f56b8 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -5,6 +5,7 @@ detect_artifact_periods, detect_saturation_periods, detect_artifact_periods_by_envelope, + detect_and_remove_artifacts, ) @@ -238,6 +239,119 @@ def test_detect_saturation_periods(debug_plots): assert np.array_equal(periods, periods_entry_with_annotation) +def test_detect_saturation_signed(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, data, axis=0) + + # Inject positive saturation in first third, negative in second third + pos_start, pos_stop = 15000, 15500 + neg_start, neg_stop = 45000, 45500 + data[pos_start:pos_stop, :] = sat_value + data[neg_start:neg_stop, :] = -sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + 
+ recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + periods = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * 0.98, + signed=True, + job_kwargs=job_kwargs, + ) + + # Output dtype must include the "sign" field + assert "sign" in periods.dtype.names + + pos_periods = periods[periods["sign"] == "positive"] + neg_periods = periods[periods["sign"] == "negative"] + assert len(pos_periods) > 0, "No positive saturation periods detected" + assert len(neg_periods) > 0, "No negative saturation periods detected" + + # Positive period should be near the injected positive saturation + tolerance = 1 + assert np.any(np.abs(pos_periods["start_sample_index"] - pos_start) <= tolerance) + assert np.any(np.abs(pos_periods["end_sample_index"] - pos_stop) <= tolerance) + + # Negative period should be near the injected negative saturation + assert np.any(np.abs(neg_periods["start_sample_index"] - neg_start) <= tolerance) + assert np.any(np.abs(neg_periods["end_sample_index"] - neg_stop) <= tolerance) + + # Positive periods must not contain any sample indices from the negative injection + for p in pos_periods: + assert not (p["start_sample_index"] < neg_stop and p["end_sample_index"] > neg_start) + + # Negative periods must not contain any sample indices from the positive injection + for p in neg_periods: + assert not (p["start_sample_index"] < pos_stop and p["end_sample_index"] > pos_start) + + +def test_detect_and_remove_artifacts(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, 
data, axis=0) + + sat_start, sat_stop = 15000, 15500 + data[sat_start:sat_stop, :] = sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + + recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + # Basic usage: detect and zero out saturation in one step + cleaned = detect_and_remove_artifacts( + recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + traces = cleaned.get_traces(segment_index=0) + assert traces[sat_start + 100, 0] == 0, "Saturated samples should be zeroed" + assert traces[0, 0] != 0, "Non-saturated samples should not be zeroed" + + # recording_to_detect: detect on raw recording, silence a separate (processed) recording + # We use the same recording here just to exercise the code path + cleaned_with_detect = detect_and_remove_artifacts( + recording, + recording_to_detect=recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + assert np.array_equal(cleaned.get_traces(), cleaned_with_detect.get_traces()) + + if __name__ == "__main__": # test_detect_artifact_by_envelope(True) test_detect_saturation_periods(False) From 99afff64336c3b246e0da49ee11e99f4ff5ba100 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 20 Apr 2026 18:10:27 +0200 Subject: [PATCH 4/7] Apply suggestion from @alejoe91 --- src/spikeinterface/preprocessing/detect_artifacts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 4a4dbf17f3..a03082747c 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -229,8 +229,7 @@ def compute( all_events = 
np.concatenate(all_events) # sort by start sample index order = np.argsort(all_events["start_sample_index"]) - all_events = all_events[order] - return (all_events,) + events = all_events[order] return (events,) From 9c25668a211317100e65ab48ec9f61071d39738e Mon Sep 17 00:00:00 2001 From: Olivier Winter Date: Wed, 22 Apr 2026 21:28:46 +0100 Subject: [PATCH 5/7] saturation application with apodization --- .../preprocessing/detect_artifacts.py | 4 +++- .../preprocessing/silence_periods.py | 21 +++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index a03082747c..a68824ca72 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -717,6 +717,7 @@ def __init__( job_kwargs: dict | None = None, mode: Literal["zeros", "noise"] = "zeros", noise_levels_kwargs: dict | None = None, + apodization: int = 7, seed: int | None = None, artifact_periods=None, ) -> None: @@ -729,7 +730,7 @@ def __init__( recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs ) super().__init__( - recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed + recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization ) self._kwargs = dict( @@ -742,6 +743,7 @@ def __init__( noise_levels_kwargs=noise_levels_kwargs, seed=seed, artifact_periods=artifact_periods, + apodization=apodization, ) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 393c712919..c77e8cd43a 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -1,4 +1,5 @@ import numpy as np +import scipy.signal from spikeinterface.core.core_tools import 
define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -48,14 +49,15 @@ def __init__( self, recording, periods=None, - # this is keep for backward compatibility + # this is kept for backward compatibility list_periods=None, mode="zeros", noise_levels=None, + apodization=7, seed=None, **noise_levels_kwargs, ): - available_modes = ("zeros", "noise") + available_modes = ("zeros", "noise", "apodization") num_seg = recording.get_num_segments() # handle backward compatibility with previous version @@ -108,11 +110,11 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index + parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) + self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -154,12 +156,13 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator + self.apodization = apodization def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) @@ -185,7 +188,13 @@ def get_traces(self, start_frame, end_frame, channel_indices): :, channel_indices ] 
traces[onset:offset, :] = noise[onset:offset] - + elif self.mode == "apodization": + # apply a cosine taper to the saturation to create a mute function + mute = np.zeros(traces.shape[0], dtype=np.float32) + mute[onset:offset] = 1 + win = scipy.signal.windows.cosine(self.apodization) + mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) + traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) return traces From 16368173cb3ed65bb39055ffc8ca294fbc849e7b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 13:26:43 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/detect_artifacts.py | 7 ++++++- .../preprocessing/silence_periods.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index a68824ca72..bb183ccbf4 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -730,7 +730,12 @@ def __init__( recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs ) super().__init__( - recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization + recording, + periods=artifact_periods, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + apodization=apodization, ) self._kwargs = dict( diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c77e8cd43a..db984b6572 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -110,11 +110,23 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = 
periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization, + parent_segment, + periods_in_seg, + mode, + noise_generator, + seg_index, + apodization=apodization, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization) + self._kwargs = dict( + recording=recording, + periods=periods, + mode=mode, + seed=seed, + noise_levels=noise_levels, + apodization=apodization, + ) def _all_period_list_to_periods_vec(list_periods, num_seg): From 230507fb0dae638ccad123cd11c3ef15547c4b23 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 23 Apr 2026 16:20:20 +0200 Subject: [PATCH 7/7] Add docstrings for apodization and tests --- .../preprocessing/detect_artifacts.py | 29 +++++++++++- .../preprocessing/silence_periods.py | 45 +++++++++++++------ .../tests/test_silence_periods.py | 21 ++++++++- 3 files changed, 78 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index a68824ca72..18cbd2f67d 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -706,6 +706,26 @@ class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. ``n_jobs``, ``chunk_duration``). + noise_levels_kwargs : dict | None, default: None + Keyword arguments for `spikeinterface.core.get_noise_levels()` function. + + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. + mode : "zeros" | "noise" | "apodization", default: "zeros" + Determines what periods are replaced by. Can be one of the following: + + - "zeros": Artifacts are replaced by zeros. 
+ + - "noise": The periods are filled with Gaussian noise that has the + same variance as the one in the recordings, on a per channel + basis + - "apodization": The periods are zeroed and their edges are apodized with a cosine taper (using `apodization`) + apodization : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + artifact_periods : np.ndarray | None, default: None + Optionally, pre-computed artifact periods can be passed directly to the constructor to skip the + detection step. If ``None``, artifact periods are detected on the fly using the specified method. """ def __init__( @@ -715,7 +735,7 @@ def __init__( method: Literal["envelope", "saturation"] = "envelope", method_kwargs: dict | None = None, job_kwargs: dict | None = None, - mode: Literal["zeros", "noise"] = "zeros", + mode: Literal["zeros", "noise", "apodization"] = "zeros", noise_levels_kwargs: dict | None = None, apodization: int = 7, seed: int | None = None, @@ -730,7 +750,12 @@ def __init__( recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs ) super().__init__( - recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization + recording, + periods=artifact_periods, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + apodization=apodization, ) self._kwargs = dict( diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c77e8cd43a..5fadc1969c 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -22,14 +22,11 @@ class SilencedPeriodsRecording(BasePreprocessor): ---------- recording : RecordingExtractor The recording extractor to silance periods - list_periods : list of 
lists/arrays - One list per segment of tuples (start_frame, end_frame) to silence - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise, default: "zeros" + periods : np.array + A numpy array with dtype `base_period_dtype` and fields + "segment_index", "start_sample_index", "end_sample_index". + Each row corresponds to a period to silence. + mode : "zeros" | "noise" | "apodization", default: "zeros" Determines what periods are replaced by. Can be one of the following: - "zeros": Artifacts are replaced by zeros. @@ -37,6 +34,14 @@ class SilencedPeriodsRecording(BasePreprocessor): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis + - "apodization": The periods are zeroed and their edges are apodized with a cosine taper (using `apodization_factor`) + apodization_factor : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + noise_levels : array + Noise levels if already computed + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + If None, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. 
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function Returns @@ -52,8 +57,8 @@ def __init__( # this is kept for backward compatibility list_periods=None, mode="zeros", + apodization_factor=7, noise_levels=None, - apodization=7, seed=None, **noise_levels_kwargs, ): @@ -110,11 +115,23 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization, + parent_segment, + periods_in_seg, + mode, + noise_generator, + seg_index, + apodization_factor=apodization_factor, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization) + self._kwargs = dict( + recording=recording, + periods=periods, + mode=mode, + seed=seed, + noise_levels=noise_levels, + apodization_factor=apodization_factor, + ) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -156,13 +173,13 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization=7): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization_factor=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator - self.apodization = apodization + self.apodization_factor = apodization_factor def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) @@ -192,7 +209,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # apply a cosine taper to the saturation to create a mute function mute = np.zeros(traces.shape[0], 
dtype=np.float32) mute[onset:offset] = 1 - win = scipy.signal.windows.cosine(self.apodization) + win = scipy.signal.windows.cosine(self.apodization_factor) mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) return traces diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index 44bd205f1b..4eee8646dc 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -39,7 +39,6 @@ def test_silence(create_cache_folder): data1 = rec.get_traces(0, 400, 600) data2 = rec.get_traces(0, 500, 700) assert np.all(data1[100:] == data2[:100]) - traces_mix = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100) traces_original = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100) assert np.all(traces_original[100:-100] == traces_mix[100:-100]) @@ -48,6 +47,26 @@ def test_silence(create_cache_folder): assert not np.all(traces_mix[:200] == 0) assert not np.all(traces_mix[:-200] == 0) + # test that the apodization creates a taper + apodization_factor = 10 + rec2 = silence_periods(rec, periods=periods, mode="apodization", apodization_factor=apodization_factor) + rec2 = rec2.save(format="memory", verbose=False, overwrite=True) + traces_in0 = rec2.get_traces(segment_index=0, start_frame=0, end_frame=1000) + traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000, end_frame=6000) + # all apodized traces + assert np.all(traces_in0 == 0) + assert np.all(traces_in1 == 0) + + # at margins, traces should not be all zero, but should be apodized + apodized_traces_in0 = rec2.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) + apodized_traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + traces_raw_in0 = 
rec.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) + traces_raw_in1 = rec.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + # the apodized traces should be less than the raw traces in absolute value, + # since they are multiplied by a cosine taper between 0 and 1 + assert np.all(np.abs(apodized_traces_in0) <= np.abs(traces_raw_in0)) + assert np.all(np.abs(apodized_traces_in1) <= np.abs(traces_raw_in1)) + if __name__ == "__main__": cache_folder = Path(__file__).resolve().parents[4] / "cache_folder"