diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1609f11d17..10e4885606 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,11 +489,13 @@ def check_graph(nodes): Check that node list is orderd in a good (parents are before children) """ - node0 = nodes[0] - if not isinstance(node0, PeakSource): - raise ValueError( - "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" - ) + # Do not remove this, this is to remenber that in previous version the first node needed to be + # a detectot but not anymore + # node0 = nodes[0] + # if not isinstance(node0, PeakSource): + # raise ValueError( + # "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" + # ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode" diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index de25944bd2..ab1adb6942 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,6 +20,12 @@ PreprocessingPipeline, ) +from .detect_artifacts import ( + detect_artifact_periods, + detect_artifact_periods_by_envelope, + detect_saturation_periods +) + # for snippets from .align_snippets import AlignSnippets from warnings import warn diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py new file mode 100644 index 0000000000..3e42facdc5 --- /dev/null +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.core.base import base_period_dtype +# from spikeinterface.core.core_tools import define_function_handling_dict_from_class +# from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording +from spikeinterface.preprocessing.rectify import RectifyRecording +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording +from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode +import numpy as np + + + +artifact_dtype = base_period_dtype + + +# this will be extend with channel boundaries if needed +# extended_artifact_dtype = artifact_dtype + [ +# # TODO +# ] + + + +def detect_artifact_periods( + recording, + method="envelope", + method_kwargs=None, + job_kwargs=None, +): + """ + Detect artifacts with several possible methods: + * 'saturation' using detect_artifact_periods_by_envelope() + * 'envelope' using detect_saturation_periods() + + See sub methods for more information on parameters. + """ + + if method_kwargs is None: + method_kwargs = dict() + + if method == "envelope": + artifact_periods, envelope = detect_artifact_periods_by_envelope(recording, **method_kwargs, job_kwargs=job_kwargs) + elif method == "saturation": + artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs) + else: + raise ValueError(f"detect_artifact_periods() method='{method}' is not valid") + + return artifact_periods + + + +## detect_period_artifacts_saturation Zone + +def _collapse_events(events): + """ + If events are detected at a chunk edge, they will be split in two. + This detects such cases and collapses them in a single record instead. + """ + order = np.lexsort((events["start_sample_index"], events["segment_index"])) + events = events[order] + to_drop = np.zeros(events.size, dtype=bool) + + # compute if duplicate + for i in np.arange(events.size - 1): + same = events["end_sample_index"][i] == events["start_sample_index"][i + 1] + if same: + to_drop[i] = True + events["start_sample_index"][i + 1] = events["start_sample_index"][i] + + return events[~to_drop].copy() + + +class _DetectSaturation(PipelineNode): + + name = "detect_saturation" + preferred_mp_context = None + _compute_has_extended_signature = True + + def __init__( + self, + recording, + saturation_threshold_uV, + voltage_per_sec_threshold, + proportion, + ): + PipelineNode.__init__(self, recording, return_output=True) + + gains = recording.get_channel_gains() + offsets = recording.get_channel_offsets() + num_chans = recording.get_num_channels() + + self.voltage_per_sec_threshold = voltage_per_sec_threshold + thresh = np.full((num_chans, ), saturation_threshold_uV) + # 0.98 is empirically determined as the true saturating point is + # slightly lower than the documented saturation point of the probe + self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 + + self.sampling_frequency = recording.get_sampling_frequency() + self.proportion = proportion + self._dtype = np.dtype(artifact_dtype) + self.gain = recording.get_channel_gains() + self.offset = recording.get_channel_offsets() + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + + saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) + + if self.voltage_per_sec_threshold is not None: + fs = self.sampling_frequency + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.voltage_per_sec_threshold, axis=1) + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + saturation = np.logical_or(saturation > self.proportion, n_diff_saturated > self.proportion) + else: + saturation = saturation > self.proportion + + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=artifact_dtype) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + + return (events, ) + + +def detect_saturation_periods( + recording, + saturation_threshold_uV, # 1200 uV + voltage_per_sec_threshold=None, # 1e-8 V.s-1 + proportion=0.5, + job_kwargs=None, +): + """ + Detect amplifier saturation events (either single sample or multi-sample periods) in the data. + Saturation detection with this function should be applied to the raw data, before preprocessing. + However, saturation periods detected should be zeroed out after preprocessing has been performed. + + Saturation is detected by a voltage threshold, and optionally a derivative threshold that + flags periods of high velocity changes in the voltage. See _DetectSaturation.compute() + for details on the algorithm. + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect the saturation events. + saturation_threshold_uV : float + The voltage saturation threshold in volts. This will depend on the recording + probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). + Note that NP2 probes are more difficult to saturate than NP1. + voltage_per_sec_threshold : None | float + The first-derivative threshold in volts per second. Periods of the data over which the change + in velocity is greater than this threshold will be detected as saturation events. Use `None` to + skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be + empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. + + proportion : + mute_window_samples : + job_kwargs : + + most useful for NP1 + can use ratio as a intuition for the value but dont do it in code + + Returns + ------- + + """ + if job_kwargs: + job_kwargs = {} + + # if saturation_threshold_uV < 0.1: + # raise ValueError(f"The `saturation_threshold_uV` should be in microvolts. " + # f"Your value: {saturation_threshold_uV} is almost certainly in volts.") + + job_kwargs = fix_job_kwargs(job_kwargs) + + node0 = _DetectSaturation( + recording, + saturation_threshold_uV=saturation_threshold_uV, + voltage_per_sec_threshold=voltage_per_sec_threshold, + proportion=proportion, + ) + + saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts") + + return _collapse_events(saturation_periods) + + + +## detect_artifact_periods_by_envelope Zone + +class _DetectThresholdCrossing(PeakDetector): + + name = "threshold_crossings" + preferred_mp_context = None + + def __init__( + self, + recording, + detect_threshold=5, + noise_levels=None, + seed=None, + noise_levels_kwargs=dict(), + ): + PeakDetector.__init__(self, recording, return_output=True) + if noise_levels is None: + random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + self.abs_thresholds = noise_levels * detect_threshold + # internal dtype + self._dtype = np.dtype([ + ("sample_index", "int64"), + ("segment_index", "int64"), + ("front", "bool") + ] + ) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + z = np.median(traces / self.abs_thresholds, 1) + threshold_mask = np.diff((z > 1) != 0, axis=0) + indices = np.flatnonzero(threshold_mask) + threshold_crossings = np.zeros(indices.size, dtype=self._dtype) + threshold_crossings["sample_index"] = indices + threshold_crossings["segment_index"] = segment_index + threshold_crossings["front"][::2] = True + threshold_crossings["front"][1::2] = False + return (threshold_crossings,) + + +def detect_artifact_periods_by_envelope( + recording, + detect_threshold=5, + # min_duration_ms=50, + freq_max=20.0, + seed=None, + job_kwargs=None, + random_slices_kwargs=None, +): + """ + Function to detect putative artifact periods as threshold crossings of + a global envelope of the channels. + + Parameters + ---------- + recording : RecordingExtractor + The recording extractor to detect putative artifacts + detect_threshold : float, default: 5 + The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` + freq_max : float, default: 20 + The maximum frequency for the low pass filter used + seed : int | None, default: None + Random seed for `get_noise_levels`. + If none, `get_noise_levels` uses `seed=0`. + **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + + """ + + envelope = RectifyRecording(recording) + envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) + envelope = CommonReferenceRecording(envelope) + + job_kwargs = fix_job_kwargs(job_kwargs) + if random_slices_kwargs is None: + random_slices_kwargs = {} + else: + random_slices_kwargs = random_slices_kwargs.copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + + node0 = _DetectThresholdCrossing( + envelope, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + ) + + threshold_crossings = run_node_pipeline( + envelope, + [node0], + job_kwargs, + job_name="detect artifact on envelope", + ) + + order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) + threshold_crossings = threshold_crossings[order] + + artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) + + return artifacts, envelope + + +def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): + + num_seg = recording.get_num_segments() + + final_artifacts = [] + for seg_index in range(num_seg): + mask = artifacts["segment_index"] == seg_index + sub_thr = artifacts[mask] + if len(sub_thr) > 0: + if not sub_thr["front"][0]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = 0 + local_thr["front"] = True + sub_thr = np.hstack((local_thr, sub_thr)) + if sub_thr["front"][-1]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = recording.get_num_samples(seg_index) + local_thr["front"] = False + sub_thr = np.hstack((sub_thr, local_thr)) + + local_artifact = np.zeros(sub_thr.size/2, dtype=artifact_dtype) + local_artifact["start_index"] = sub_thr["sample_index"][::2] + local_artifact["stop_index"] = sub_thr["sample_index"][1::2] + local_artifact["segment_index"] = seg_index + final_artifacts.append(local_artifact) + + if len(final_artifacts) > 0: + final_artifacts = np.concatenate(final_artifacts) + else: + final_artifacts = np.zeros(0, dtype=artifact_dtype) + return final_artifacts \ No newline at end of file diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index fe9d95c506..47839db7a0 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -50,7 +50,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts +# from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { # filter stuff @@ -90,7 +90,7 @@ DirectionalDerivativeRecording: directional_derivative, AstypeRecording: astype, UnsignedToSignedRecording: unsigned_to_signed, - SilencedArtifactsRecording: silence_artifacts, + # SilencedArtifactsRecording: silence_artifacts, } # we control import in the preprocessing init by setting an __all__ diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py deleted file mode 100644 index f8323b2e78..0000000000 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from spikeinterface.core.base import base_peak_dtype -from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector -from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -from spikeinterface.preprocessing.rectify import RectifyRecording -from spikeinterface.preprocessing.common_reference import CommonReferenceRecording -from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording - - -class DetectThresholdCrossing(PeakDetector): - - name = "threshold_crossings" - preferred_mp_context = None - - def __init__( - self, - recording, - detect_threshold=5, - noise_levels=None, - seed=None, - noise_levels_kwargs=dict(), - ): - PeakDetector.__init__(self, recording, return_output=True) - if noise_levels is None: - random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("front", "bool")]) - - def get_trace_margin(self): - return 0 - - def get_dtype(self): - return self._dtype - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces / self.abs_thresholds, 1) - threshold_mask = np.diff((z > 1) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) - threshold_crossings = np.zeros(indices.size, dtype=self._dtype) - threshold_crossings["sample_index"] = indices - threshold_crossings["front"][::2] = True - threshold_crossings["front"][1::2] = False - return (threshold_crossings,) - - -def detect_period_artifacts_by_envelope( - recording, - detect_threshold=5, - min_duration_ms=50, - freq_max=20.0, - seed=None, - noise_levels=None, - **noise_levels_kwargs, -): - """ - Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of - a global envelope of the channels. - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to detect putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels`. - If none, `get_noise_levels` uses `seed=0`. - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - """ - - envelope = RectifyRecording(recording) - envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) - envelope = CommonReferenceRecording(envelope) - - from spikeinterface.core.node_pipeline import ( - run_node_pipeline, - ) - - _, job_kwargs = split_job_kwargs(noise_levels_kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) - - node0 = DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - threshold_crossings = run_node_pipeline( - recording, - [node0], - job_kwargs, - job_name="detect threshold crossings", - ) - - order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) - threshold_crossings = threshold_crossings[order] - - periods = [] - fs = recording.sampling_frequency - max_duration_samples = int(min_duration_ms * fs / 1000) - num_seg = recording.get_num_segments() - - for seg_index in range(num_seg): - sub_periods = [] - mask = threshold_crossings["segment_index"] == seg_index - sub_thr = threshold_crossings[mask] - if len(sub_thr) > 0: - local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) - if not sub_thr["front"][0]: - local_thr["sample_index"] = 0 - local_thr["front"] = True - sub_thr = np.hstack((local_thr, sub_thr)) - if sub_thr["front"][-1]: - local_thr["sample_index"] = recording.get_num_samples(seg_index) - local_thr["front"] = False - sub_thr = np.hstack((sub_thr, local_thr)) - - indices = np.flatnonzero(np.diff(sub_thr["front"])) - for i, j in zip(indices[:-1], indices[1:]): - if sub_thr["front"][i]: - start = sub_thr["sample_index"][i] - end = sub_thr["sample_index"][j] - if end - start > max_duration_samples: - sub_periods.append((start, end)) - - periods.append(sub_periods) - - return periods, envelope - - -class SilencedArtifactsRecording(SilencedPeriodsRecording): - """ - Silence user-defined periods from recording extractor traces. The code will construct - an enveloppe of the recording (as a low pass filtered version of the traces) and detect - threshold crossings to identify the periods to silence. The periods are then silenced either - on a per channel basis or across all channels by replacing the values by zeros or by - adding gaussian noise with the same variance as the one in the recordings - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to silence putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise", default: "zeros" - Determines what periods are replaced by. Can be one of the following: - - - "zeros": Artifacts are replaced by zeros. - - - "noise": The periods are filled with a gaussion noise that has the - same variance that the one in the recordings, on a per channel - basis - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - Returns - ------- - silenced_recording : SilencedArtifactsRecording - The recording extractor after silencing detected artifacts - """ - - _precomputable_kwarg_names = ["list_periods"] - - def __init__( - self, - recording, - detect_threshold=5, - verbose=False, - freq_max=20.0, - min_duration_ms=50, - mode="zeros", - noise_levels=None, - seed=None, - list_periods=None, - **noise_levels_kwargs, - ): - - if list_periods is None: - list_periods, _ = detect_period_artifacts_by_envelope( - recording, - detect_threshold=detect_threshold, - min_duration_ms=min_duration_ms, - freq_max=freq_max, - seed=seed, - noise_levels=noise_levels, - **noise_levels_kwargs, - ) - - if verbose: - for i, periods in enumerate(list_periods): - total_time = np.sum([end - start for start, end in periods]) - percentage = 100 * total_time / recording.get_num_samples(i) - print(f"{percentage}% of segment {i} has been flagged as artifactual") - - SilencedPeriodsRecording.__init__( - self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - -# function for API -silence_artifacts = define_function_handling_dict_from_class( - source_class=SilencedArtifactsRecording, name="silence_artifacts" -) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c9b6e2abe4..040e1275be 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -8,6 +8,8 @@ from spikeinterface.core import get_noise_levels from spikeinterface.core.generate import NoiseGeneratorRecording from spikeinterface.core.job_tools import split_job_kwargs +from spikeinterface.core.base import base_period_dtype + class SilencedPeriodsRecording(BasePreprocessor): @@ -48,7 +50,9 @@ class SilencedPeriodsRecording(BasePreprocessor): def __init__( self, recording, - list_periods, + periods=None, + # this is keep for backward compatibility + list_periods=None, mode="zeros", noise_levels=None, seed=None, @@ -56,25 +60,27 @@ def __init__( ): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() - if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: - # when unique segment accept list instead of list of list/arrays - list_periods = [list_periods] - # some checks - assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" - assert isinstance(list_periods, list), "'list_periods' must be a list (one per segment)" - assert len(list_periods) == num_seg, "'list_periods' must have the same length as the number of segments" - assert all( - isinstance(list_periods[i], (list, np.ndarray)) for i in range(num_seg) - ), "Each element of 'list_periods' must be array-like" + # handle backward compatibility with previous version + if list_periods is not None: + assert periods is None + periods = _all_period_list_to_periods_vec(list_periods, num_seg) + else: + assert list_periods is None + if not isinstance(periods, np.ndarray): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + if periods.dtype.fields is None: + # this is the old format : list[list[int]] + periods = _all_period_list_to_periods_vec(periods, num_seg) - for periods in list_periods: - if len(periods) > 0: - assert np.all(np.diff(np.array(periods), axis=1) > 0), "t_stops should be larger than t_starts" - assert np.all( - periods[i][1] < periods[i + 1][0] for i in np.arange(len(periods) - 1) - ), "Intervals should not overlap" + # force order + order = np.lexsort((periods["start_sample_index"], periods["segment_index"])) + periods = periods[order] + _check_periods(periods, num_seg) + + # some checks + assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" if mode in ["noise"]: if noise_levels is None: @@ -98,18 +104,57 @@ def __init__( noise_generator = None BasePreprocessor.__init__(self, recording) + + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for seg_index, parent_segment in enumerate(recording._recording_segments): - periods = list_periods[seg_index] - periods = np.asarray(periods, dtype="int64") - periods = np.sort(periods, axis=0) - rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) + i0 = seg_limits[seg_index] + i1 = seg_limits[seg_index+1] + periods_in_seg = periods[i0:i1] + rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods_in_seg, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) self._kwargs = dict( - recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels + recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels ) +def _all_period_list_to_periods_vec(list_periods, num_seg): + if num_seg == 1: + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: + # when unique segment accept list instead of list of list/arrays + list_periods = [list_periods] + size = sum(len(p) for p in list_periods) + periods = np.zeros(size, dtype=base_period_dtype) + start = 0 + for i in range(num_seg): + periods_in_seg = list_periods[i] + stop = start + periods_in_seg.shape[0] + periods[start:stop]["segment_index"] = i + periods[start:stop]["start_sample_index"] = periods_in_seg[:, 0] + periods[start:stop]["end_sample_index"] = periods_in_seg[:, 1] + start = stop + return periods + +def _check_periods(periods, num_seg): + # check dtype + if any(col not in np.dtype(base_period_dtype).fields for col in periods.dtype.fields): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + # check non overlap and non negative + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) + for i in range(num_seg): + i0 = seg_limits[i] + i1 = seg_limits[i+1] + periods_in_seg = periods[i0:i1] + if periods_in_seg.size == 0: + continue + if len(periods) > 0: + if np.any(periods_in_seg["start_sample_index"] > periods_in_seg["end_sample_index"]): + raise ValueError("end_sample_index should be larger than start_sample_index") + if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): + raise ValueError("Intervals should not overlap") + + class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -120,18 +165,20 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) - traces = traces.copy() + if self.periods.size > 0: new_interval = np.array([start_frame, end_frame]) - lower_index = np.searchsorted(self.periods[:, 1], new_interval[0]) - upper_index = np.searchsorted(self.periods[:, 0], new_interval[1]) + + lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) + upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) if upper_index > lower_index: - periods_in_interval = self.periods[lower_index:upper_index] + traces = traces.copy() + periods_in_interval = self.periods[lower_index:upper_index] for period in periods_in_interval: - onset = max(0, period[0] - start_frame) - offset = min(period[1] - start_frame, end_frame) + onset = max(0, period["start_sample_index"] - start_frame) + offset = min(period["end_sample_index"] - start_frame, end_frame) if self.mode == "zeros": traces[onset:offset, :] = 0 @@ -143,8 +190,52 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces - # function for API silence_periods = define_function_handling_dict_from_class( source_class=SilencedPeriodsRecording, name="silence_periods" ) + + + +class DetectArtifactAndSilentPeriodsRecording(SilencedPeriodsRecording): + """ + Class doing artifact detection and lient at the same time. + + See SilencedPeriodsRecording and detect_artifact_periods for details. + """ + + _precomputable_kwarg_names = ["artifacts"] + + def __init__( + self, + recording, + detect_artifact_method="envelope", + detect_artifact_kwargs=dict(), + periods=None, + mode="zeros", + noise_levels=None, + seed=None, + **noise_levels_kwargs, + ): + + if artifacts is None: + from spikeinterface.preprocessing import detect_artifact_periods + artifacts = detect_artifact_periods( + recording, + method=detect_artifact_method, + method_kwargs=detect_artifact_kwargs, + job_kwargs=None, + ) + + SilencedPeriodsRecording.__init__( + self, recording, periods=artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs + ) + # note self._kwargs["periods"] is done by SilencedPeriodsRecording and so the computaion is done once + + + +# function for API +detect_artifacts_and_silent_periods = define_function_handling_dict_from_class( + source_class=DetectArtifactAndSilentPeriodsRecording, name="silence_artifacts" +) + diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py new file mode 100644 index 0000000000..b5d9a18a9b --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -0,0 +1,125 @@ +from spikeinterface.core import generate_recording, NumpyRecording +from spikeinterface.preprocessing import detect_artifact_periods, detect_saturation_periods +import numpy as np + + +def test_detect_artifact_periods(): + # one segment only + rec = generate_recording(durations=[10.0, 10]) + artifacts = detect_artifact_periods(rec, method="envelope", + method_kwargs=dict(detect_threshold=5, freq_max=5.0), + ) + + + +def test_detect_saturation_periods(): + + import scipy.signal + + """ + TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity + we have an extra sample after due to taking the diff on the final saturation mask + this means we always take one sample before and one sample after the diff period, which is fine. + """ + # num_chans = 384 + num_chans = 32 + sample_frequency = 30000 + chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below + job_kwargs = {"chunk_size": chunk_size} + + # cross a chunk boundary. Do not change without changing the below. + sat_value = 1200 + rng = np.random.default_rng() + data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 + + # Design the Butterworth filter + sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") + + # Apply the filter to the data + data_seg_1 = scipy.signal.sosfiltfilt(sos, data, axis=0) + data_seg_2 = data_seg_1.copy() + + # Add test saturation at the start, end of recording + # as well as across and within chunks (30k samples). + # Two cases which are not tested are a single event + # exactly on the border, as it makes testing complex + # This was checked manually and any future breaking change + # on this function would be extremely unlikely only to break this case. + all_starts = np.array([0, 29950, 45123, 90005, 149500]) + all_stops = np.array([1001, 30011, 45126, 90006, 149999]) + + second_seg_offset = 1 + for start, stop in zip(all_starts, all_stops): + data_seg_1[start : stop, :] = sat_value + # differentiate the second segment for testing purposes + data_seg_2[start : stop + second_seg_offset, :] = sat_value + + # this center the int16 around 0 and saturate on positive + max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) + gain = max_ / 2**15 + offset = 0 + + seg_1_int16 = np.clip( + np.rint((data_seg_1 - offset) / gain), + -32768, 32767 + ).astype(np.int16) + seg_2_int16 = np.clip( + np.rint((data_seg_2 - offset) / gain), + -32768, 32767 + ).astype(np.int16) + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.plot(seg_1_int16[:, 0]) + # plt.show() + + recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + periods = detect_saturation_periods( + recording, saturation_threshold_uV=sat_value * 0.98, voltage_per_sec_threshold=1e-8, job_kwargs=job_kwargs + ) + + seg_1_periods = periods[np.where(periods["segment_index"] == 0)] + seg_2_periods = periods[np.where(periods["segment_index"] == 1)] + + # For the start times, all are one sample before the actual saturated + # period starts because the derivative threshold is exceeded at one + # sample before the saturation starts. Therefore this one-sample-offset + # on the start times is an implicit test that the derivative + # threshold is working properly. + for seg_periods in [seg_1_periods, seg_2_periods]: + assert seg_periods["start_sample_index"][0] == all_starts[0] + assert np.array_equal(seg_periods["start_sample_index"][1:], np.array(all_starts)[1:] - 1) + + assert np.array_equal(seg_1_periods["end_sample_index"], np.array(all_stops)) + assert np.array_equal(seg_2_periods["end_sample_index"], np.array(all_stops) + second_seg_offset) + + # Just do a quick test that a threshold slightly over the sat value is not detected. + # In this case we only see the derivative threshold detection. We do not play around with this + # threshold because the derivative threshold is not easy to predict (the baseline sample is random). + periods = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * (1 / 0.98), + voltage_per_sec_threshold=1e-8, + job_kwargs=job_kwargs, + ) + assert periods["start_sample_index"][0] == 1000 + assert periods["end_sample_index"][0] == 1001 + + periods = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=sat_value * (1 / 0.98), + voltage_per_sec_threshold=1e-8, + ), + job_kwargs=job_kwargs, + ) + + + +if __name__ == "__main__": + test_detect_artifact_periods() + test_detect_saturation_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py deleted file mode 100644 index 2baa4bf1b3..0000000000 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -import numpy as np - -from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import silence_artifacts - - -def test_silence_artifacts(): - # one segment only - rec = generate_recording(durations=[10.0, 10]) - new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) - - -if __name__ == "__main__": - test_silence_artifacts() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py similarity index 76% rename from src/spikeinterface/preprocessing/tests/test_silence.py rename to src/spikeinterface/preprocessing/tests/test_silence_periods.py index e7aee1a84d..ffba9059a0 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -1,11 +1,12 @@ import pytest from spikeinterface.core import generate_recording - +from spikeinterface.core import get_noise_levels +from spikeinterface.core.base import base_period_dtype from spikeinterface.preprocessing import silence_periods -from spikeinterface.core import get_noise_levels + import numpy as np @@ -18,17 +19,20 @@ def test_silence(create_cache_folder): rec = generate_recording() - rec0 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros", seed=2308) - rec0.save(verbose=False) + periods = np.array([(0, 0, 1000), (0, 5000, 6000)], dtype=base_period_dtype) + rec0 = silence_periods(rec, periods=periods, mode="zeros", seed=2308) + rec0.save(format="memory", verbose=False) traces_in0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000) - traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) - traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert np.all(traces_in0 == 0) + traces_half0 = rec0.get_traces(segment_index=0, start_frame=900, end_frame=1100) + assert np.all(traces_half0[:100] == 0) + traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) assert np.all(traces_in1 == 0) + traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert not np.all(traces_out0 == 0) - rec1 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="noise", seed=2308) - rec1 = rec1.save(folder=cache_folder / "rec_w_noise", verbose=False, overwrite=True) + rec1 = silence_periods(rec, periods=periods, mode="noise", seed=2308) + rec1 = rec1.save(format="memory", verbose=False, overwrite=True) noise_levels = get_noise_levels(rec, return_in_uV=False) traces_in0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000) traces_in1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000)