diff --git a/doc/index.rst b/doc/index.rst index ce4053ca43..659efe85a8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -30,7 +30,7 @@ SpikeInterface is made of several modules to deal with different aspects of the - visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) - export a report and/or export to phy - curate your sorting with several strategies (ml-based, metrics based, manual, ...) -- offer a powerful Qt-based or we-based viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. +- offer a powerful desktop or web viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. - have powerful sorting components to build your own sorter. - have a full motion/drift correction framework (See :ref:`motion_correction`) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 078db82201..0faed14fba 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -93,10 +93,11 @@ with 16 channels: timestamps = np.arange(num_samples) / sampling_frequency + 300 recording.set_times(times=timestamps, segment_index=0) -**Note**: -Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units (uV), you can apply a gain and an offset. -Many devices have their own gains and offsets necessary to convert their data and these values are handled by SpikeInterface for its extractors. This -is triggered by the :code:`return_in_uV` parameter in :code:`get_traces()`, (see above example), which will return the traces in uV. Read more in our how to guide, :ref:`physical_units`. +.. note:: + + Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units (uV), you can apply a gain and an offset. + Many devices have their own gains and offsets necessary to convert their data and these values are handled by SpikeInterface for its extractors. This + is triggered by the :code:`return_in_uV` parameter in :code:`get_traces()`, (see above example), which will return the traces in uV. Read more in our how to guide, :ref:`physical_units`. Sorting @@ -180,8 +181,9 @@ a numpy.array with dtype `[("sample_index", "int64"), ("unit_index", "int64"), ( For computations which are done unit-by-unit, like computing isi-violations per unit, it is better that spikes from a single unit are concurrent in memory. For these other cases, we can re-order the `spike_vector` in different ways: - * order by unit, then segment, then sample - * order by segment, then unit, then sample + +* order by unit, then segment, then sample +* order by segment, then unit, then sample This is done using `sorting.to_reordered_spike_vector()`. The first time a reordering is done, the reordered spiketrain is cached in memory by default. Users should rarely have to worry about these @@ -458,9 +460,11 @@ It represents unsorted waveform cutouts. Some acquisition systems, in fact, allo threshold and only record the times at which a peak was detected and the waveform cut out around the peak. -**NOTE**: while we support this class (mainly for legacy formats), this approach is a bad practice -and is highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform -template matching to recover spikes! +.. note:: + + While we support this class (mainly for legacy formats), this approach is a bad practice + and is highly discouraged! 
Most modern spike sorters, in fact, require the raw traces to perform + template matching to recover spikes! Here we assume :code:`snippets` is a :py:class:`~spikeinterface.core.BaseSnippets` object with 16 channels: @@ -548,9 +552,11 @@ Sparsity is defined as the subset of channels on which waveforms (and related in sparsity is not global, but it is unit-specific. Importantly, saving sparse waveforms, especially for high-density probes, dramatically reduces the size of the waveforms extension if computed. -**NOTE** As of :code:`0.101.0` all :code:`SortingAnalyzer`'s have a default of :code:`sparse=True`. This was first -introduced in :code:`0.99.0` for :code:`WaveformExtractor`'s and will be the default going forward. To obtain dense -waveforms you will need to set :code:`sparse=False` at the creation of the :code:`SortingAnalyzer`. +.. note:: + + As of :code:`0.101.0` all :code:`SortingAnalyzer`'s have a default of :code:`sparse=True`. This was first + introduced in :code:`0.99.0` for :code:`WaveformExtractor`'s and will be the default going forward. To obtain dense + waveforms you will need to set :code:`sparse=False` at the creation of the :code:`SortingAnalyzer`. Sparsity can be computed from a :py:class:`~spikeinterface.core.SortingAnalyzer` object with the @@ -854,10 +860,12 @@ The same functions are also available for :py:func:`~spikeinterface.core.select_segment_sorting`). -**Note** :py:func:`~spikeinterface.core.append_recordings` and:py:func:`~spikeinterface.core.concatenate_recordings` -have the same goal, aggregate recording pieces on the time axis but with 2 different strategies! One is keeping the -multi segments concept, the other one is breaking it! -See this example for more detail :ref:`example_segments`. +.. note:: + + :py:func:`~spikeinterface.core.append_recordings` and:py:func:`~spikeinterface.core.concatenate_recordings` + have the same goal, aggregate recording pieces on the time axis but with 2 different strategies! One is keeping the + multi segments concept, the other one is breaking it! + See this example for more detail :ref:`example_segments`. diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 4eeaa23b81..ab25e8803d 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -12,9 +12,11 @@ and behavioral data. It can be used to decode behavior, make tuning curves, comp The :py:func:`~spikeinterface.exporters.to_pynapple_tsgroup` function allows you to convert a SortingAnalyzer to Pynapple's ``TsGroup`` object on the fly. -**Note** : When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. -How this works depends on your acquisition system. You can use the ``get_times`` method on a recording -(``my_recording.get_times()``) to find the time support of your recording. +.. note:: + + When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. + How this works depends on your acquisition system. You can use the ``get_times`` method on a recording + (``my_recording.get_times()``) to find the time support of your recording. When constructed, if ``attach_unit_metadata`` is set to ``True``, any relevant unit information is propagated to the ``TsGroup``. The ``to_pynapple_tsgroup`` checks if unit locations, quality @@ -54,13 +56,15 @@ The :py:func:`~spikeinterface.exporters.export_to_phy` function allows you to us `Phy template GUI `_ for visual inspection and manual curation of spike sorting results. 
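+
+Below is a minimal sketch of a typical export call (the ``output_folder`` value and
+the pre-computed extension names are illustrative; adapt them to your pipeline).
+Pre-computing the extensions that Phy displays keeps the export fast, as discussed
+in the note below:
+
+.. code-block:: python
+
+    from spikeinterface.exporters import export_to_phy
+
+    # optional but recommended: compute the extensions Phy will display
+    sorting_analyzer.compute(
+        ["random_spikes", "waveforms", "templates", "spike_amplitudes", "principal_components"]
+    )
+
+    export_to_phy(sorting_analyzer, output_folder="path/to/phy_folder")
+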
-**Note** : :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend -on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. -The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. -So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), -then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). -If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) -can be computed directly at export time. +.. note:: + + :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend + on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. + The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. + So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), + then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). + If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) + can be computed directly at export time. The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:`SortingAnalyzer` object. @@ -131,12 +135,14 @@ The report includes summary figures of the spike sorting output (e.g. amplitude depth VS amplitude) as well as unit-specific reports, that include waveforms, templates, template maps, ISI distributions, and more. -**Note** : similarly to :py:func:`~spikeinterface.exporters.export_to_phy` the -:py:func:`~spikeinterface.exporters.export_report` depends on the sparsity of the :code:`SortingAnalyzer` itself and -on which extensions have been computed. For example, :code:`spike_amplitudes` and :code:`correlograms` related plots -will be automatically included in the report if the associated extensions are computed in advance. -The function can perform these computations as well, but it is a better practice to compute everything that's needed -beforehand. +.. note:: + + Similarly to :py:func:`~spikeinterface.exporters.export_to_phy` the + :py:func:`~spikeinterface.exporters.export_report` depends on the sparsity of the :code:`SortingAnalyzer` itself and + on which extensions have been computed. For example, :code:`spike_amplitudes` and :code:`correlograms` related plots + will be automatically included in the report if the associated extensions are computed in advance. + The function can perform these computations as well, but it is a better practice to compute everything that's needed + beforehand. Note that every unit will generate a summary unit figure, so the export process can be slow for spike sorting outputs with many units! diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 5442b4728c..5c4e29b359 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -163,8 +163,10 @@ Extensions are generally saved in two ways, suitable for two workflows: :code:`sorting_analyzer.compute('waveforms', save=False)`). -**NOTE**: We recommend choosing a workflow and sticking with it. 
Either keep everything on disk or keep everything in memory until -you'd like to save. A mixture can lead to unexpected behavior. For example, consider the following code +.. note:: + + We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until + you'd like to save. A mixture can lead to unexpected behavior. For example, consider the following code .. code:: @@ -257,15 +259,35 @@ spike_amplitudes This extension computes the amplitude of each spike as the value of the traces on the extremum channel at the times of each spike. The extremum channel is computed from the templates. + **NOTE:** computing spike amplitudes is highly recommended before calculating amplitude-based quality metrics, such as :ref:`amp_cutoff` and :ref:`amp_median`. .. code-block:: python - amplitudes = sorting_analyzer.compute(input="spike_amplitudes", peak_sign="neg") + amplitudes = sorting_analyzer.compute(input="spike_amplitudes") For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes` + +.. _postprocessing_amplitude_scalings: + +amplitude_scalings +^^^^^^^^^^^^^^^^^^ + +This extension computes the amplitude scaling of each spike as the value of the linear fit between the template and the +spike waveform. In case of spatio-temporal collisions, a multi-linear fit is performed using the templates of all units +involved in the collision. + +**NOTE:** computing amplitude scalings is highly recommended before calculating amplitude-based quality metrics, such as +:ref:`amp_cutoff` and :ref:`amp_median`. + +.. code-block:: python + + amplitude_scalings = sorting_analyzer.compute(input="amplitude_scalings") + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings` + .. _postprocessing_spike_locations: spike_locations @@ -367,7 +389,29 @@ This extension computes the histograms of inter-spike-intervals. The computed ou method="auto" ) -For more information, see :py:func:`~spikeinterface.postprocessing.compute_isi_histograms` +valid_unit_periods +^^^^^^^^^^^^^^^^^^ + +This extension computes the valid unit periods for each unit based on the estimation of false positive rates +(using RP violation - see ::doc:`metrics/qualitymetrics/isi_violations`) and false negative rates +(using amplitude cutoff - see ::doc:`metrics/qualitymetrics/amplitude_cutoff`) computed over chunks of the recording. +The valid unit periods are the periods where both false positive and false negative rates are below specified +thresholds. Periods can be either absolute (in seconds), same for all units, or relative, where +chunks will be unit-specific depending on firing rate (with a target number of spikes per chunk). + +.. code-block:: python + + valid_periods = sorting_analyzer.compute( + input="valid_unit_periods", + period_mode='relative', + target_num_spikes=300, + fp_threshold=0.1, + fn_threshold=0.1, + ) + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_valid_unit_periods`. 
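+The returned periods form a structured array of ``unit_period_dtype`` with fields
+``segment_index``, ``start_sample_index``, ``end_sample_index`` and ``unit_index``.
+A minimal sketch of inspecting the result (assuming the extension was computed as above):
+
+.. code-block:: python
+
+    ext = sorting_analyzer.get_extension("valid_unit_periods")
+    periods = ext.get_data(outputs="numpy")
+
+    # valid periods of the first unit, converted to seconds
+    fs = sorting_analyzer.sampling_frequency
+    unit_periods = periods[periods["unit_index"] == 0]
+    starts_s = unit_periods["start_sample_index"] / fs
+    ends_s = unit_periods["end_sample_index"] / fs
+
+    # per-period false positive / false negative estimates (one dict per segment)
+    fps, fns = ext.get_fps_and_fns()
+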
+ + Other postprocessing tools diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index e76cf3f99d..5549fd0317 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -81,7 +81,9 @@ Other variants are also implemented (but less tested or not so useful): * **'by_channel_torch'** (requires :code:`torch`): pytorch implementation (GPU-compatible) that uses max pooling for time deduplication * **'locally_exclusive_torch'** (requires :code:`torch`): pytorch implementation (GPU-compatible) that uses max pooling for space-time deduplication -**NOTE**: the torch implementations give slightly different results due to a different implementation. +.. note:: + + The torch implementations give slightly different results due to a different implementation. Peak detection, as many of the other sorting components, can be run in parallel. @@ -274,7 +276,7 @@ handle drift can benefit from drift estimation/correction. Especially for acute Neuropixels-like probes, this is a crucial step. The motion estimation step comes after peak detection and peak localization. Read more about -it in the :ref:`_motion_correction` modules doc, and a more practical guide in the +it in the :ref:`motion_correction` modules doc, and a more practical guide in the :ref:`handle-drift-in-your-recording` How To. Here is an example with non-rigid motion estimation: diff --git a/doc/references.rst b/doc/references.rst index be05b69c0c..49f8b33add 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -118,6 +118,8 @@ References .. [Diggelmann] `Automatic spike sorting for high-density microelectrode arrays. 2018. `_ +.. [Fabre] `Bombcell: automated curation and cell classification of spike-sorted electrophysiology data. 2023. ` + .. [Garcia2024] `A Modular Implementation to Handle and Benchmark Drift Correction for High-Density Extracellular Recordings. 2024. `_ .. [Garcia2022] `How Do Spike Collisions Affect Spike Sorting Performance? `_ diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index bc0a1871af..6092a0a275 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -5,7 +5,7 @@ import numpy as np -from spikeinterface.core.base import BaseExtractor +from spikeinterface.core.base import BaseExtractor, unit_period_dtype from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -264,6 +264,31 @@ def select_sorting_periods_mask(sorting: BaseSorting, periods): return keep_mask +def cast_periods_to_unit_period_dtype(periods): + if not periods.dtype == unit_period_dtype: + if periods.ndim != 2 or periods.shape[1] != 4: + raise ValueError( + "If periods is not of dtype unit_period_dtype, it must be a 2D array with shape (num_periods, 4)" + ) + warnings.warn( + "periods is not of dtype unit_period_dtype. 
Assuming fields are in order: " + "(segment_index, start_sample_index, end_sample_index, unit_index).", + UserWarning, + ) + # convert to structured array + periods_converted = np.empty(periods.shape[0], dtype=unit_period_dtype) + periods_converted["segment_index"] = periods[:, 0] + periods_converted["start_sample_index"] = periods[:, 1] + periods_converted["end_sample_index"] = periods[:, 2] + periods_converted["unit_index"] = periods[:, 3] + periods = periods_converted + else: + required = set(np.dtype(unit_period_dtype).names) + if not required.issubset(periods.dtype.names): + raise ValueError(f"Period must have the following fields: {required}") + return periods + + def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: """ Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. @@ -282,33 +307,12 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: A new sorting object with only samples between start_sample_index and end_sample_index for the given segment_index. """ - from spikeinterface.core.base import unit_period_dtype from spikeinterface.core.numpyextractors import NumpySorting if periods is not None: if not isinstance(periods, np.ndarray): raise ValueError("periods must be a numpy array") - if not periods.dtype == unit_period_dtype: - if periods.ndim != 2 or periods.shape[1] != 4: - raise ValueError( - "If periods is not of dtype unit_period_dtype, it must be a 2D array with shape (num_periods, 4)" - ) - warnings.warn( - "periods is not of dtype unit_period_dtype. Assuming fields are in order: " - "(segment_index, start_sample_index, end_sample_index, unit_index).", - UserWarning, - ) - # convert to structured array - periods_converted = np.empty(periods.shape[0], dtype=unit_period_dtype) - periods_converted["segment_index"] = periods[:, 0] - periods_converted["start_sample_index"] = periods[:, 1] - periods_converted["end_sample_index"] = periods[:, 2] - periods_converted["unit_index"] = periods[:, 3] - periods = periods_converted - - required = set(np.dtype(unit_period_dtype).names) - if not required.issubset(periods.dtype.names): - raise ValueError(f"Period must have the following fields: {required}") + periods = cast_periods_to_unit_period_dtype(periods) spike_vector = sorting.to_spike_vector() keep_mask = select_sorting_periods_mask(sorting, periods) @@ -387,10 +391,9 @@ def apply_merges_to_sorting( all_unit_ids = list(all_unit_ids) num_seg = sorting.get_num_segments() - seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) - segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)] + segment_slices = sorting._get_spike_vector_segment_slices() - # using this function vaoid to use the mask approach and simplify a lot the algo + # using this function avoids to use the mask approach and simplify a lot the algo spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) @@ -904,3 +907,60 @@ def _get_ids_after_splitting(old_unit_ids, split_units, new_unit_ids): all_unit_ids.remove(split_unit) all_unit_ids.extend(split_new_units) return np.array(all_unit_ids, dtype=dtype) + + +def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, keep_old_unit_ids=None): + """ + Remap the "unit_index" field in a spike vector (or period vector) according to new unit ids. 
+ + This is useful for instance when you: + * select unit and recompute quickly the "unit_index" in the spike vector + * merging/spliting periods or spikes and update the "unit_index" in the vector + + + Parameters + ---------- + vector : numpy.array + The spike vector with a "unit_index" field. + all_old_unit_ids : numpy.array + The array of all old unit ids. + all_new_unit_ids : list + The list of all new unit ids. + keep_old_unit_ids : list | None, default: None + The list of old unit ids to keep. If None, all old unit ids are kept. + This is useful when some units are merged or split during curation, + since we don't want to keep them in the remapping + return + """ + all_old_unit_ids = np.asarray(all_old_unit_ids) + all_new_unit_ids = np.asarray(all_new_unit_ids) + assert ( + all_old_unit_ids.size == np.unique(all_old_unit_ids).size + ), "remap_unit_indices_in_vector: all_old_unit_ids not unique" + assert ( + all_new_unit_ids.size == np.unique(all_new_unit_ids).size + ), "remap_unit_indices_in_vector: all_new_unit_ids not unique" + + if keep_old_unit_ids is None: + keep_old_unit_ids = all_old_unit_ids + + # this mask has shape all_old_unit_ids.shape + mask_keep_unit = np.isin(all_old_unit_ids, keep_old_unit_ids) & np.isin(all_old_unit_ids, all_new_unit_ids) + + all_new_unit_ids = list(all_new_unit_ids) + mapping = np.zeros(all_old_unit_ids.size, dtype=int) + mapping[:] = -1 + # keep = np.zeros(all_old_unit_ids.size, dtype=bool) + for old_unit_ind, old_unit_id in enumerate(all_old_unit_ids): + if not mask_keep_unit[old_unit_ind]: + continue + new_unit_index = all_new_unit_ids.index(old_unit_id) + mapping[old_unit_ind] = new_unit_index + # keep[old_unit_ind] = True + + # this mask has shape vector.shape + keep_mask_vector = mask_keep_unit[vector["unit_index"]] + new_vector = vector[keep_mask_vector] + new_vector["unit_index"] = mapping[new_vector["unit_index"]] + + return new_vector, keep_mask_vector diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index e078f71ed4..8de45210cd 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1136,7 +1136,11 @@ def _save_or_select_or_merge_or_split( raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") recompute_dict = {} - for extension_name, extension in self.extensions.items(): + extensions_to_compute = _sort_extensions_by_dependency( + {ext.extension_name: ext.params for ext in self.extensions.values()} + ) + for extension_name in extensions_to_compute: + extension = self.extensions[extension_name] if merge_unit_groups is None and split_units is None: # copy full or select new_sorting_analyzer.extensions[extension_name] = extension.copy( @@ -2471,6 +2475,18 @@ def get_any_dependencies(cls, **params): all_dependencies = list(chain.from_iterable([dep.split("|") for dep in all_dependencies])) return all_dependencies + @classmethod + def get_default_params(cls): + """ + Get the default params for the extension. + + Returns + ------- + default_params : dict + The default parameters for the extension. 
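+
+        Examples
+        --------
+        A quick illustration (any registered extension class exposes this
+        classmethod; ``ComputeSpikeAmplitudes`` is only used as an example):
+
+        >>> from spikeinterface.postprocessing import ComputeSpikeAmplitudes
+        >>> defaults = ComputeSpikeAmplitudes.get_default_params()
+        >>> isinstance(defaults, dict)
+        True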
+ """ + return get_default_analyzer_extension_params(cls.extension_name) + def load_run_info(self): run_info = None if self.format == "binary_folder": @@ -2667,8 +2683,9 @@ def _save_data(self): extension_folder = self._get_binary_extension_folder() for ext_data_name, ext_data in self.data.items(): if isinstance(ext_data, dict): + ext_data_ = check_json(ext_data) with (extension_folder / f"{ext_data_name}.json").open("w") as f: - json.dump(ext_data, f) + json.dump(ext_data_, f) elif isinstance(ext_data, np.ndarray): data_file = extension_folder / f"{ext_data_name}.npy" if isinstance(ext_data, np.memmap) and data_file.exists(): @@ -2698,10 +2715,12 @@ def _save_data(self): for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] - if isinstance(ext_data, dict): + if isinstance(ext_data, (dict, list)): + ext_data_ = check_json(ext_data) extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() + name=ext_data_name, data=np.array([ext_data_], dtype=object), object_codec=numcodecs.JSON() ) + extension_group[ext_data_name].attrs["dict"] = True elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): @@ -2884,6 +2903,7 @@ def set_data(self, ext_data_name, ext_data): "spike_locations": "spikeinterface.postprocessing", "template_similarity": "spikeinterface.postprocessing", "unit_locations": "spikeinterface.postprocessing", + "valid_unit_periods": "spikeinterface.postprocessing", # from metrics "quality_metrics": "spikeinterface.metrics", "template_metrics": "spikeinterface.metrics", diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 87ea5e53f6..4194f459b3 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -12,7 +12,9 @@ apply_merges_to_sorting, _get_ids_after_merging, generate_unit_ids_for_merge_group, + remap_unit_indices_in_vector, ) +from spikeinterface.core.base import minimum_spike_dtype @pytest.mark.skipif( @@ -163,11 +165,68 @@ def test_generate_unit_ids_for_merge_group(): assert np.array_equal(new_unit_ids, ["0-5", "9-15"]) +def test_remap_unit_indices_in_vector(): + + unit_ids = ["a", "b", "c", "d", "e"] + n_spikes = 20 + n_units = len(unit_ids) + + spikes = np.zeros(n_spikes, dtype=minimum_spike_dtype) + spikes["unit_index"] = np.arange(n_spikes) % n_units + # the sample should remain the original unit_index after transform + spikes["sample_index"] = np.arange(n_spikes) % n_units + # print(spikes) + + # remove some units + # so 0->0, 2->1, 4->2 + new_unit_ids = ["a", "c", "e"] + new_spikes, mask = remap_unit_indices_in_vector(spikes, unit_ids, new_unit_ids, keep_old_unit_ids=None) + assert np.all(np.isin(new_spikes["unit_index"], [0, 1, 2])) + assert new_spikes.size == n_spikes * len(new_unit_ids) // n_units + # print(new_spikes) + + # rename units in reverse order + # so 0->4, 1->3, 2->2, 3->1, 4->0 + new_unit_ids = ["e", "d", "c", "b", "a"] + new_spikes, mask = remap_unit_indices_in_vector(spikes, unit_ids, new_unit_ids, keep_old_unit_ids=None) + assert new_spikes.size == spikes.size + assert np.all(new_spikes["unit_index"] == 4 - new_spikes["sample_index"]) + # print(new_spikes) + + # add some new units + # vector unchanged + new_unit_ids = ["a", "b", "c", "d", "e", "f", "g"] + 
new_spikes, mask = remap_unit_indices_in_vector(spikes, unit_ids, new_unit_ids, keep_old_unit_ids=None) + assert np.array_equal(new_spikes, spikes) + # print(new_spikes) + + # add some + remove some + # so 0->0, 2->1, 4->2 + new_unit_ids = ["a", "c", "e", "f", "g"] + new_spikes, mask = remap_unit_indices_in_vector(spikes, unit_ids, new_unit_ids, keep_old_unit_ids=None) + assert np.all(np.isin(new_spikes["unit_index"], [0, 1, 2])) + assert new_spikes.size == n_spikes * 3 // n_units + # print(new_spikes) + + # remove one unit which is also in the new unit set + # the unit_id="e" (index=4) will not be in new set + new_unit_ids = ["a", "b", "c", "d", "e"] + keep_old_unit_ids = ["a", "b", "c", "d"] + new_spikes, mask = remap_unit_indices_in_vector(spikes, unit_ids, new_unit_ids, keep_old_unit_ids=keep_old_unit_ids) + assert np.all(np.isin(new_spikes["unit_index"], [0, 1, 2, 3])) + assert new_spikes.size == n_spikes * 4 // n_units + target_mask = np.ones(spikes.size, dtype=bool) + target_mask[4::5] = False + assert np.array_equal(mask, target_mask) + + if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() - test_random_spikes_selection() + # test_random_spikes_selection() # test_apply_merges_to_sorting() # test_get_ids_after_merging() # test_generate_unit_ids_for_merge_group() + + test_remap_unit_indices_in_vector() diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index f0909fbc8a..70dfab66b3 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -48,25 +48,24 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None): self._kwargs = dict(parent_sorting=parent_sorting, unit_ids=unit_ids, renamed_unit_ids=renamed_unit_ids) def _compute_and_cache_spike_vector(self) -> None: + from spikeinterface.core.sorting_tools import remap_unit_indices_in_vector + if self._parent_sorting._cached_spike_vector is None: self._parent_sorting._compute_and_cache_spike_vector() if self._parent_sorting._cached_spike_vector is None: return - parent_spike_vector = self._parent_sorting._cached_spike_vector - parent_unit_indices = self._parent_sorting.ids_to_indices(self._unit_ids) - sort_indices = np.argsort(parent_unit_indices) - mask = np.isin(parent_spike_vector["unit_index"], parent_unit_indices) - spike_vector = np.array( - parent_spike_vector[mask] - ) # np.array() necessary to fix 'read-only' crash with memmaps. - indices = np.searchsorted( - parent_unit_indices, spike_vector["unit_index"], sorter=sort_indices - ) # Trick to make sure that the new indices are correct. 
- spike_vector["unit_index"] = np.arange(len(parent_unit_indices))[sort_indices][indices] - - self._cached_spike_vector = spike_vector + spike_vector, _ = remap_unit_indices_in_vector( + vector=self._parent_sorting._cached_spike_vector, + all_old_unit_ids=self._parent_sorting.unit_ids, + all_new_unit_ids=self._unit_ids, + ) + # lexsort by segment_index, sample_index, unit_index + sort_indices = np.lexsort( + (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + ) + self._cached_spike_vector = spike_vector[sort_indices] class UnitsSelectionSortingSegment(BaseSortingSegment): diff --git a/src/spikeinterface/metrics/quality/__init__.py b/src/spikeinterface/metrics/quality/__init__.py index 1edcd9221f..f91ed6eefc 100644 --- a/src/spikeinterface/metrics/quality/__init__.py +++ b/src/spikeinterface/metrics/quality/__init__.py @@ -20,4 +20,5 @@ compute_sliding_rp_violations, compute_sd_ratio, compute_synchrony_metrics, + compute_refrac_period_violations, ) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index e5cc2aa323..5476aa405a 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -49,6 +49,13 @@ class ComputeQualityMetrics(BaseMetricExtension): need_backward_compatibility_on_load = True metric_list = misc_metrics_list + pca_metrics_list + @classmethod + def get_required_dependencies(cls, **params): + if params.get("use_valid_periods", False): + return ["valid_unit_periods"] + else: + return [] + def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames qm_params as metric_params if (qm_params := self.params.get("qm_params")) is not None: @@ -70,6 +77,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + use_valid_periods=False, periods=None, # common extension kwargs peak_sign=None, @@ -86,6 +94,11 @@ def _set_params( pc_metric_names = [m.metric_name for m in pca_metrics_list] metric_names = [m for m in metric_names if m not in pc_metric_names] + if use_valid_periods: + if periods is not None: + raise ValueError("If use_valid_periods is True, periods should not be provided.") + periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy") + return super()._set_params( metric_names=metric_names, metric_params=metric_params, diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index 2e87002018..dfd47c4df9 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from spikeinterface.core import ( @@ -172,6 +171,76 @@ def test_empty_units(sorting_analyzer_simple): assert all_nans or all_zeros or all_neg_ones, f"Column {col} failed the empty unit test" +def test_quality_metrics_with_periods(): + """ + Test that quality metrics can be computed using valid unit periods. 
+ """ + from spikeinterface.core.base import unit_period_dtype + + recording, sorting = generate_ground_truth_recording() + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") + + # compute dependencies + sorting_analyzer.compute(["random_spikes", "templates", "amplitude_scalings", "valid_unit_periods"], **job_kwargs) + print(sorting_analyzer) + + # compute quality metrics using valid periods + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + seed=2205, + ) + print(metrics) + + # test with external periods: 1 period per segment from 10 to 90% of recording + num_segments = recording.get_num_segments() + periods = np.zeros(len(sorting.unit_ids) * num_segments, dtype=unit_period_dtype) + for i, unit_id in enumerate(sorting.unit_ids): + unit_index = sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = int(num_samples * 0.1) + period_end = int(num_samples * 0.9) + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_end + periods[idx]["segment_index"] = segment_index + + metrics_ext_periods = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=False, + periods=periods, + seed=2205, + ) + + # test failure when both periods and use_valid_periods are set + with pytest.raises(ValueError): + compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + periods=periods, + seed=2205, + ) + + # test failure if use valid_periods is True but valid_unit_periods extension is missing + sorting_analyzer.delete_extension("valid_unit_periods") + with pytest.raises(AssertionError): + compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + seed=2205, + ) + + if __name__ == "__main__": sorting_analyzer = get_sorting_analyzer() diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index dca9711ccd..555c9a5d3b 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -44,3 +44,8 @@ ComputeTemplateMetrics, compute_template_metrics, ) + +from .valid_unit_periods import ( + ComputeValidUnitPeriods, + compute_valid_unit_periods, +) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8f3ffe0617..473798fe7c 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -127,7 +127,11 @@ def _get_pipeline_nodes(self): sparsity = self.params["sparsity"] else: if self.params["max_dense_channels"] is not None: - assert recording.get_num_channels() <= self.params["max_dense_channels"], "" + assert recording.get_num_channels() <= self.params["max_dense_channels"], ( + "Sparsity must be provided when the number of channels is " + f"greater than {self.params['max_dense_channels']}. Alternatively, set max_dense_channels to None " + "to compute amplitude scalings using dense waveforms." 
+ ) sparsity = ChannelSparsity.create_dense(self.sorting_analyzer) sparsity_mask = sparsity.mask diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index ce3d1cd4a9..aef44e7f56 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -153,7 +153,7 @@ def _merge_extension_data( # check if it is mapped to itself if old_unit == new_unit_id: old_to_new_unit_index_map[old_unit_index] = new_unit_index - # or to a unit_id outwith the old ones + # or to a unit_id without the old ones elif new_unit_id not in self.sorting_analyzer.unit_ids: if new_unit_index not in old_to_new_unit_index_map.values(): old_to_new_unit_index_map[old_unit_index] = new_unit_index @@ -192,7 +192,7 @@ def _merge_extension_data( return new_data def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # TODO: for now we just copy + # for splits, we need to recompute correlograms new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) new_data = dict(ccgs=new_ccgs, bins=new_bins) return new_data diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 69b5b7fb0b..e5586e79c7 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -4,9 +4,13 @@ import shutil import numpy as np -from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.core import estimate_sparsity +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, + load_sorting_analyzer, + estimate_sparsity, +) +from spikeinterface.core.sortinganalyzer import get_extension_class extensions_which_allow_unit_ids = ["unit_locations"] @@ -97,7 +101,19 @@ def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=Non return sorting_analyzer - def _prepare_sorting_analyzer(self, format, sparse, extension_class): + def _compute_extensions_recursively(self, sorting_analyzer, extension_class, params): + # compute dependencies of the extension class with default params + dependencies = extension_class.get_required_dependencies(**params) + for dependency_name in dependencies: + if "|" in dependency_name: + dependency_name = dependency_name.split("|")[0] + if not sorting_analyzer.has_extension(dependency_name): + # compute dependencies of the dependency + self._compute_extensions_recursively(sorting_analyzer, get_extension_class(dependency_name), {}) + # compute the dependency itself + sorting_analyzer.compute(dependency_name) + + def _prepare_sorting_analyzer(self, format, sparse, extension_class, extension_params=None): # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None sorting_analyzer = self.get_sorting_analyzer( @@ -105,10 +121,11 @@ def _prepare_sorting_analyzer(self, format, sparse, extension_class): ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) - for dependency_name in extension_class.depend_on: - if "|" in dependency_name: - dependency_name = dependency_name.split("|")[0] - sorting_analyzer.compute(dependency_name) + # default params for dependencies + params = 
sorting_analyzer.get_default_extension_params(extension_class.extension_name) + if extension_params is not None: + params.update(extension_params) + self._compute_extensions_recursively(sorting_analyzer, extension_class, params) return sorting_analyzer @@ -128,7 +145,7 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext = sorting_analyzer.compute(extension_class.extension_name, **params, **job_kwargs) assert len(ext.data) > 0 main_data = ext.get_data() - assert len(main_data) > 0 + assert main_data is not None ext = sorting_analyzer.get_extension(extension_class.extension_name) assert ext is not None @@ -159,7 +176,11 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name) for ext_data_name, ext_data_loaded in ext_loaded.data.items(): if isinstance(ext_data_loaded, np.ndarray): - assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) + if len(ext_data_loaded) > 0 and isinstance(ext_data_loaded[0], dict): + for i in range(len(ext_data_loaded)): + assert np.array_equal(np.array(ext.data[ext_data_name][i]), np.array(ext_data_loaded[i])) + else: + assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) elif isinstance(ext_data_loaded, pd.DataFrame): # skip nan values for col in ext_data_loaded.columns: @@ -182,5 +203,7 @@ def run_extension_tests(self, extension_class, params): for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): print("sparse", sparse, format) - sorting_analyzer = self._prepare_sorting_analyzer(format, sparse, extension_class) + sorting_analyzer = self._prepare_sorting_analyzer( + format, sparse, extension_class, extension_params=params + ) self._check_one(sorting_analyzer, extension_class, params) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index fc9d3643bc..77bff7a3d8 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -99,9 +99,9 @@ def test_get_projections(self, sparse): some_unit_ids = sorting_analyzer.unit_ids[::2] some_channel_ids = sorting_analyzer.channel_ids[::2] - random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() - all_num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit() - unit_ids_num_spikes = np.sum(all_num_spikes[unit_id] for unit_id in some_unit_ids) + random_spikes_ext = sorting_analyzer.get_extension("random_spikes") + random_spikes_indices = random_spikes_ext.get_data() + unit_ids_num_random_spikes = np.sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) @@ -113,7 +113,7 @@ def test_get_projections(self, sparse): # this should be some spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert spike_unit_index.shape[0] == unit_ids_num_random_spikes assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans assert 1 not in spike_unit_index @@ -123,7 +123,7 @@ def test_get_projections(self, sparse): channel_ids=some_channel_ids, unit_ids=some_unit_ids ) assert 
some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert spike_unit_index.shape[0] == unit_ids_num_random_spikes assert some_projections.shape[1] == n_components assert some_projections.shape[2] == some_channel_ids.size assert 1 not in spike_unit_index diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py new file mode 100644 index 0000000000..6d34264eac --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -0,0 +1,131 @@ +import pytest +import numpy as np + +from spikeinterface.core.base import unit_period_dtype +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeValidUnitPeriods + + +class TestComputeValidUnitPeriods(AnalyzerExtensionCommonTestSuite): + + @pytest.mark.parametrize( + "params", + [ + dict(period_mode="absolute", period_duration_s_absolute=1.1, minimum_valid_period_duration=1.0), + dict(period_mode="relative", period_target_num_spikes=30, minimum_valid_period_duration=1.0), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeValidUnitPeriods, params) + + def test_user_defined_periods(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods = np.zeros(len(unit_ids) * num_segments, dtype=unit_period_dtype) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["segment_index"] = segment_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods + ) + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to user defined periods + ext_periods = ext.get_data(outputs="numpy") + np.testing.assert_array_equal(ext_periods, periods) + + def test_user_defined_periods_as_arrays(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods_array = np.zeros((len(unit_ids) * num_segments, 4), dtype="int64") + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods_array[idx, 0] = segment_index + periods_array[idx, 1] = period_start + periods_array[idx, 2] = period_start + period_duration + periods_array[idx, 3] = unit_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods + ) + ext = sorting_analyzer.compute( + 
ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods_array, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to user defined periods + ext_periods = ext.get_data(outputs="numpy") + ext_periods = np.column_stack([ext_periods[field] for field in ext_periods.dtype.names]) + np.testing.assert_array_equal(ext_periods, periods_array) + + # test that dropping segment_index raises because multi-segment + with pytest.raises(ValueError): + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods_array[:, 1:4], # drop segment_index + minimum_valid_period_duration=1, + ) + + def test_combined_periods(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods = np.zeros(len(unit_ids) * num_segments, dtype=unit_period_dtype) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["segment_index"] = segment_index + + unit_valid_periods_params = dict( + method="combined", + user_defined_periods=periods, + period_mode="absolute", + period_duration_s_absolute=1.0, + minimum_valid_period_duration=1, + ) + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods, extension_params=unit_valid_periods_params + ) + ext = sorting_analyzer.compute(ComputeValidUnitPeriods.extension_name, **unit_valid_periods_params) + # check that valid periods correspond to intersection of auto-computed and user defined periods + ext_periods = ext.get_data(outputs="numpy") + assert len(ext_periods) <= len(periods) # should be less or equal than user defined ones diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py new file mode 100644 index 0000000000..070387d034 --- /dev/null +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -0,0 +1,802 @@ +from __future__ import annotations + +import importlib.util +import warnings + +import numpy as np +from typing import Optional +from copy import deepcopy + +from concurrent.futures import ProcessPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +from spikeinterface.core.base import unit_period_dtype +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.sorting_tools import cast_periods_to_unit_period_dtype, remap_unit_indices_in_vector +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from spikeinterface.metrics.spiketrain import compute_firing_rates + +numba_spec = importlib.util.find_spec("numba") +if numba_spec is not None: + HAVE_NUMBA = True +else: + HAVE_NUMBA = False + + +class ComputeValidUnitPeriods(AnalyzerExtension): + """Compute valid unit periods for units. 
+ By default, the extension uses the "false_positives_and_negatives" method, which computes amplitude cutoffs + (false negative rate) and refractory period violations (false positive rate) over chunks of data + to estimate valid periods. External user-defined periods can also be provided. + + Parameters + ---------- + method : "false_positives_and_negatives" | "user_defined" | "combined", default: "false_positives_and_negatives" + Strategy for identifying good periods for each unit. If "false_positives_and_negatives", uses + amplitude cutoff (false negative spike rate) and refractory period violations (false positive spike rate) + to estimate good periods (as periods with fn_rate 0 or period_target_num_spikes > 0 + ), "Either period_duration_s_absolute or period_target_num_spikes must be positive." + assert isinstance(period_target_num_spikes, (int)), "period_target_num_spikes must be an integer." + + # user_defined_periods formatting + self.user_defined_periods = None + if user_defined_periods is not None: + try: + user_defined_periods = np.asarray(user_defined_periods) + except Exception as e: + raise ValueError( + ( + "user_defined_periods must be some (n_periods, 3) [unit, good_period_start, good_period_end] " + "or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end] structure convertible to a numpy array" + ) + ) + + user_defined_periods = cast_periods_to_unit_period_dtype(user_defined_periods) + + # assert that user-defined periods are not too short + fs = self.sorting_analyzer.sampling_frequency + durations = user_defined_periods["end_sample_index"] - user_defined_periods["start_sample_index"] + min_duration_samples = int(minimum_valid_period_duration * fs) + if np.any(durations < min_duration_samples): + raise ValueError( + f"All user-defined periods must be at least {minimum_valid_period_duration} seconds long." 
+ ) + self.user_defined_periods = user_defined_periods + + params = dict( + method=method, + period_duration_s_absolute=period_duration_s_absolute, + period_target_num_spikes=period_target_num_spikes, + period_mode=period_mode, + relative_margin_size=relative_margin_size, + min_num_periods_relative=min_num_periods_relative, + fp_threshold=fp_threshold, + fn_threshold=fn_threshold, + minimum_n_spikes=minimum_n_spikes, + minimum_valid_period_duration=minimum_valid_period_duration, + refractory_period_ms=refractory_period_ms, + censored_period_ms=censored_period_ms, + num_histogram_bins=num_histogram_bins, + histogram_smoothing_value=histogram_smoothing_value, + amplitudes_bins_min_ratio=amplitudes_bins_min_ratio, + ) + + return params + + def _select_extension_data(self, unit_ids): + new_extension_data = {} + new_valid_periods, _ = remap_unit_indices_in_vector( + self.data["valid_unit_periods"], self.sorting_analyzer.unit_ids, unit_ids + ) + new_extension_data["valid_unit_periods"] = new_valid_periods + all_periods = self.data.get("all_periods", None) + if all_periods is not None: + new_all_periods, keep_mask = remap_unit_indices_in_vector( + vector=all_periods, all_old_unit_ids=self.sorting_analyzer.unit_ids, all_new_unit_ids=unit_ids + ) + new_extension_data["all_periods"] = new_all_periods + new_extension_data["fps"] = self.data["fps"][keep_mask] + new_extension_data["fns"] = self.data["fns"][keep_mask] + + return new_extension_data + + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs + ): + new_extension_data = {} + # remove data of merged units + merged_unit_ids = np.concatenate(merge_unit_groups) + untouched_unit_ids = [u for u in self.sorting_analyzer.unit_ids if u not in merged_unit_ids] + new_valid_periods, _ = remap_unit_indices_in_vector( + vector=self.data["valid_unit_periods"], + all_old_unit_ids=self.sorting_analyzer.unit_ids, + all_new_unit_ids=new_sorting_analyzer.unit_ids, + keep_old_unit_ids=untouched_unit_ids, + ) + + if self.params["method"] in ("false_positives_and_negatives", "combined"): + # need to recompute for merged units + recompute = True + else: + # in case of user-defined periods, just merge periods + recompute = False + + if recompute: + new_all_periods, keep_all_periods_mask = remap_unit_indices_in_vector( + vector=self.data["all_periods"], + all_old_unit_ids=self.sorting_analyzer.unit_ids, + all_new_unit_ids=new_sorting_analyzer.unit_ids, + keep_old_unit_ids=untouched_unit_ids, + ) + new_fps = self.data["fps"][keep_all_periods_mask] + new_fns = self.data["fns"][keep_all_periods_mask] + + # recompute for merged units + valid_periods_merged, all_periods_merged, fps_merged, fns_merged = self._compute_valid_periods( + new_sorting_analyzer, + unit_ids=new_unit_ids, + ) + + new_valid_periods = np.concatenate((new_valid_periods, valid_periods_merged), axis=0) + new_all_periods = np.concatenate((new_all_periods, all_periods_merged), axis=0) + new_fps = np.concatenate((new_fps, fps_merged), axis=0) + new_fns = np.concatenate((new_fns, fns_merged), axis=0) + + new_extension_data["valid_unit_periods"], _ = self._sort_periods(new_valid_periods) + new_extension_data["all_periods"], sort_indices = self._sort_periods(new_all_periods) + new_extension_data["fps"] = new_fps[sort_indices] + new_extension_data["fns"] = new_fns[sort_indices] + else: + # just merge periods + valid_periods_merged = [] + original_valid_periods = self.data["valid_unit_periods"] + for unit_ids, new_unit_id in 
zip(merge_unit_groups, new_unit_ids): + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + new_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id) + # get periods of all units to be merged + merge_mask = np.isin(original_valid_periods["unit_index"], unit_indices) + masked_periods = original_valid_periods[merge_mask] + masked_periods["unit_index"] = new_unit_index + valid_periods_merged.append(masked_periods) + + valid_periods_merged = np.concatenate(valid_periods_merged, axis=0) + # now merge with unsplit periods + new_valid_periods = np.concatenate((new_valid_periods, valid_periods_merged), axis=0) + # sort and merge + new_valid_periods = merge_overlapping_periods_across_units_and_segments(new_valid_periods) + new_valid_periods, _ = self._sort_periods(new_valid_periods) + new_extension_data["valid_unit_periods"] = new_valid_periods + + return new_extension_data + + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_extension_data = {} + # remove data of split units + split_unit_ids = list(split_units.keys()) + untouched_unit_ids = [u for u in self.sorting_analyzer.unit_ids if u not in split_unit_ids] + new_valid_periods, _ = remap_unit_indices_in_vector( + vector=self.data["valid_unit_periods"], + all_old_unit_ids=self.sorting_analyzer.unit_ids, + all_new_unit_ids=new_sorting_analyzer.unit_ids, + keep_old_unit_ids=untouched_unit_ids, + ) + if self.params["method"] in ("false_positives_and_negatives", "combined"): + # need to recompute for split units + recompute = True + else: + # in case of user-defined periods, we can only duplicate valid periods for the split + recompute = False + + if recompute: + new_all_periods, keep_all_periods_mask = remap_unit_indices_in_vector( + vector=self.data["all_periods"], + all_old_unit_ids=self.sorting_analyzer.unit_ids, + all_new_unit_ids=new_sorting_analyzer.unit_ids, + keep_old_unit_ids=untouched_unit_ids, + ) + new_fps = self.data["fps"][keep_all_periods_mask] + new_fns = self.data["fns"][keep_all_periods_mask] + + # recompute for split units + new_unit_ids = np.concatenate(new_unit_ids) + + valid_periods_split, all_periods_split, fps_split, fns_split = self._compute_valid_periods( + new_sorting_analyzer, + unit_ids=new_unit_ids, + ) + + new_valid_periods = np.concatenate((new_valid_periods, valid_periods_split), axis=0) + new_all_periods = np.concatenate((new_all_periods, all_periods_split), axis=0) + new_fps = np.concatenate((new_fps, fps_split), axis=0) + new_fns = np.concatenate((new_fns, fns_split), axis=0) + + new_extension_data["valid_unit_periods"], _ = self._sort_periods(new_valid_periods) + new_extension_data["all_periods"], sort_indices = self._sort_periods(new_all_periods) + new_extension_data["fps"] = new_fps[sort_indices] + new_extension_data["fns"] = new_fns[sort_indices] + else: + # just duplicate periods to the split units + valid_periods_split = [] + original_valid_periods = self.data["valid_unit_periods"] + split_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(split_units) + for split_unit_id, new_unit_ids in zip(split_units, new_unit_ids): + unit_index = self.sorting_analyzer.sorting.id_to_index(split_unit_id) + new_unit_indices = new_sorting_analyzer.sorting.ids_to_indices(new_unit_ids) + split_unit_indices.append(unit_index) + # get periods of all units to be merged + masked_periods = original_valid_periods[original_valid_periods["unit_index"] == unit_index] + for new_unit_index in new_unit_indices: + _split_periods 
+                    _split_periods["unit_index"] = new_unit_index
+                    valid_periods_split.append(_split_periods)
+            valid_periods_split = np.concatenate(valid_periods_split, axis=0)
+            # now merge with unsplit periods
+            new_valid_periods = np.concatenate((new_valid_periods, valid_periods_split), axis=0)
+            # sort and merge
+            new_valid_periods = merge_overlapping_periods_across_units_and_segments(new_valid_periods)
+            new_valid_periods, _ = self._sort_periods(new_valid_periods)
+            new_extension_data["valid_unit_periods"] = new_valid_periods
+
+        return new_extension_data
+
+    def _run(self, unit_ids=None, verbose=False, **job_kwargs):
+        valid_unit_periods, all_periods, fps, fns = self._compute_valid_periods(
+            self.sorting_analyzer,
+            unit_ids=unit_ids,
+            **job_kwargs,
+        )
+        self.data["valid_unit_periods"] = valid_unit_periods
+        if all_periods is not None:
+            self.data["all_periods"] = all_periods
+        if fps is not None:
+            self.data["fps"] = fps
+        if fns is not None:
+            self.data["fns"] = fns
+
+    def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs):
+        if self.params["method"] == "user_defined":
+
+            # directly use user defined periods
+            return self.user_defined_periods, None, None, None
+
+        elif self.params["method"] in ["false_positives_and_negatives", "combined"]:
+
+            # two structured arrays of dtype unit_period_dtype (4 fields), concatenated over units and segments:
+            # the periods themselves and the same periods extended by the margins
+            all_periods, all_periods_w_margins = compute_subperiods(
+                sorting_analyzer,
+                self.params["period_duration_s_absolute"],
+                self.params["period_target_num_spikes"],
+                self.params["period_mode"],
+                self.params["relative_margin_size"],
+                self.params["min_num_periods_relative"],
+                unit_ids=unit_ids,
+            )
+
+            job_kwargs = fix_job_kwargs(job_kwargs)
+            n_jobs = job_kwargs["n_jobs"]
+            progress_bar = job_kwargs["progress_bar"]
+            max_threads_per_worker = job_kwargs["max_threads_per_worker"]
+            mp_context = job_kwargs["mp_context"]
+
+            # Compute fp and fn for all periods
+            # Process units in parallel
+            amp_scalings = sorting_analyzer.get_extension("amplitude_scalings")
+            all_amplitudes_by_unit = amp_scalings.get_data(outputs="by_unit", concatenated=False)
+
+            init_args = (sorting_analyzer.sorting, all_amplitudes_by_unit, self.params, max_threads_per_worker)
+
+            # Each item is one computation of fp and fn for one period and one unit
+            items = [(period,) for period in all_periods_w_margins]
+            job_name = "computing false positives and negatives"
+
+            # parallel
+            with ProcessPoolExecutor(
+                max_workers=n_jobs,
+                initializer=fp_fn_worker_init,
+                mp_context=mp.get_context(mp_context),
+                initargs=init_args,
+            ) as executor:
+                results = executor.map(fp_fn_worker_func_wrapper, items)
+
+                if progress_bar:
+                    results = tqdm(results, desc=f"{job_name} (workers: {n_jobs} processes)", total=len(items))
+
+                all_fps = np.zeros(len(all_periods))
+                all_fns = np.zeros(len(all_periods))
+                for i, (fp, fn) in enumerate(results):
+                    all_fps[i] = fp
+                    all_fns[i] = fn
+
+            # set NaNs to 1 (they will be excluded anyway)
+            all_fps[np.isnan(all_fps)] = 1.0
+            all_fns[np.isnan(all_fns)] = 1.0
+
+            valid_period_mask = (all_fps < self.params["fp_threshold"]) & (all_fns < self.params["fn_threshold"])
+            valid_unit_periods = all_periods[valid_period_mask]
+
+            # Combine with user-defined periods if provided
+            if self.params["method"] == "combined":
+                user_defined_periods = self.user_defined_periods
+                valid_unit_periods = np.concatenate((valid_unit_periods, user_defined_periods), axis=0)
+
+            # Sort good periods on segment_index, unit_index, start_sample_index
+            valid_unit_periods, _ = self._sort_periods(valid_unit_periods)
+            valid_unit_periods = merge_overlapping_periods_across_units_and_segments(valid_unit_periods)
+
+            # Remove good periods that are too short
+            minimum_valid_period_duration = self.params["minimum_valid_period_duration"]
+            min_valid_period_samples = int(minimum_valid_period_duration * sorting_analyzer.sampling_frequency)
+            duration_samples = valid_unit_periods["end_sample_index"] - valid_unit_periods["start_sample_index"]
+            valid_mask = duration_samples >= min_valid_period_samples
+            valid_unit_periods = valid_unit_periods[valid_mask]
+
+            return valid_unit_periods, all_periods, all_fps, all_fns
+
+    def get_fps_and_fns(self, unit_ids=None):
+        """Get false positives and false negatives per segment and unit.
+
+        Parameters
+        ----------
+        unit_ids : list | None
+            List of unit IDs to get false positives and negatives for. If None, returns for all units.
+
+        Returns
+        -------
+        fps : list
+            List (per segment) of dictionaries mapping unit IDs to arrays of false positive rates.
+        fns : list
+            List (per segment) of dictionaries mapping unit IDs to arrays of false negative rates.
+        """
+        # split values by segment and units
+        all_periods = self.data.get("all_periods", None)
+        if all_periods is None:
+            return None, None
+        all_fps = self.data["fps"]
+        all_fns = self.data["fns"]
+
+        if unit_ids is None:
+            unit_ids = self.sorting_analyzer.unit_ids
+
+        num_segments = len(np.unique(all_periods["segment_index"]))
+        fps = []
+        fns = []
+        for segment_index in range(num_segments):
+            fp_in_segment = {}
+            fn_in_segment = {}
+            segment_mask = all_periods["segment_index"] == segment_index
+            periods_segment = all_periods[segment_mask]
+            fps_segment = all_fps[segment_mask]
+            fns_segment = all_fns[segment_mask]
+            for unit_id in unit_ids:
+                unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
+                unit_mask = periods_segment["unit_index"] == unit_index
+                fp_in_segment[unit_id] = fps_segment[unit_mask]
+                fn_in_segment[unit_id] = fns_segment[unit_mask]
+            fps.append(fp_in_segment)
+            fns.append(fn_in_segment)
+
+        return fps, fns
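# Illustrative sketch, not part of this patch: reading the per-period FP/FN estimates returned by
# `get_fps_and_fns`. `analyzer` is an assumed, already-existing SortingAnalyzer on which the
# "valid_unit_periods" extension has been computed; the 5% threshold is an arbitrary example value.
import numpy as np

ext = analyzer.get_extension("valid_unit_periods")
fps, fns = ext.get_fps_and_fns()
unit_id = analyzer.unit_ids[0]
fp_unit = fps[0][unit_id]  # segment 0: one FP estimate per period for this unit
fn_unit = fns[0][unit_id]  # segment 0: one FN estimate per period for this unit
print(f"unit {unit_id}: {np.sum(fp_unit < 0.05)}/{fp_unit.size} periods under 5% contamination")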
+ """ + all_periods = self.data.get("all_periods", None) + if all_periods is None: + return None + if unit_ids is None: + unit_ids = self.sorting_analyzer.unit_ids + + num_segments = len(np.unique(all_periods["segment_index"])) + all_period_centers = [] + for segment_index in range(num_segments): + period_centers = {} + periods_segment = all_periods[all_periods["segment_index"] == segment_index] + for unit_id in unit_ids: + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + periods_unit = periods_segment[periods_segment["unit_index"] == unit_index] + period_samples = (periods_unit["start_sample_index"] + periods_unit["end_sample_index"]) // 2 + # period_samples are the same for all bins (per unit), so we can just take the first one + period_centers[unit_id] = periods_unit["start_sample_index"] + period_samples[0] + all_period_centers.append(period_centers) + return all_period_centers + + def _get_data(self, outputs: str = "by_unit"): + """ + Return extension data. If the extension computes more than one `nodepipeline_variables`, + the `return_data_name` is used to specify which one to return. + + Parameters + ---------- + outputs : "numpy" | "by_unit", default: "by_unit" + How to return the data. + + Returns + ------- + numpy.ndarray | list + The periods in numpy or dictionary by unit format, depending on `outputs`. + If "numpy", returns an array of dtype unit_period_dtype with columns: + unit_index, segment_index, start_sample_index, end_sample_index. + If "by_unit", returns a list (per segment) of dictionaries mapping unit IDs to lists of + (start_sample_index, end_sample_index) tuples. + """ + if outputs == "numpy": + good_periods = self.data["valid_unit_periods"].copy() + else: + # by_unit + unit_ids = self.sorting_analyzer.unit_ids + good_periods = [] + good_periods_array = self.data["valid_unit_periods"] + for segment_index in range(self.sorting_analyzer.get_num_segments()): + segment_mask = good_periods_array["segment_index"] == segment_index + periods_dict = {} + for unit_index in unit_ids: + periods_dict[unit_index] = [] + unit_mask = good_periods_array["unit_index"] == unit_index + good_periods_unit_segment = good_periods_array[segment_mask & unit_mask] + for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]: + periods_dict[unit_index].append((start, end)) + good_periods.append(periods_dict) + + return good_periods + + def _sort_periods(self, periods): + sort_idx = np.lexsort((periods["start_sample_index"], periods["unit_index"], periods["segment_index"])) + sorted_periods = periods[sort_idx] + return sorted_periods, sort_idx + + +def compute_subperiods( + sorting_analyzer, + period_duration_s_absolute: float = 10, + period_target_num_spikes: int = 1000, + period_mode: str = "absolute", + relative_margin_size: float = 1.0, + min_num_periods_relative: int = 5, + unit_ids: Optional[list] = None, +) -> dict: + """ + Computes subperiods per unit based on specified size mode. + + Returns + ------- + all_subperiods : dict + Dictionary mapping unit IDs to lists of subperiods (arrays of dtype unit_period_dtype). 
+ """ + sorting = sorting_analyzer.sorting + fs = sorting.sampling_frequency + if unit_ids is None: + unit_ids = sorting.unit_ids + + if period_mode == "absolute": + period_sizes_samples = {u: np.round(period_duration_s_absolute * fs).astype(int) for u in unit_ids} + else: # relative + mean_firing_rates = compute_firing_rates(sorting_analyzer, unit_ids) + period_sizes_samples = { + u: np.round((period_target_num_spikes / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids + } + margin_sizes_samples = {u: np.round(relative_margin_size * period_sizes_samples[u]).astype(int) for u in unit_ids} + + all_subperiods = [] + all_subperiods_w_margins = [] + for segment_index in range(sorting.get_num_segments()): + n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples + for unit_id in unit_ids: + unit_index = sorting.id_to_index(unit_id) + period_size_samples = period_sizes_samples[unit_id] + margin_size_samples = margin_sizes_samples[unit_id] + # We round the number of subperiods to ensure coverage of the entire recording + # the end of the last period is then clipped or extended to the end of the recording + n_subperiods = round(n_samples / period_size_samples) + if period_mode == "relative" and n_subperiods < min_num_periods_relative: + n_subperiods = min_num_periods_relative # at least min_num_periods_relative subperiods + period_size_samples = n_samples // n_subperiods + margin_size_samples = int(relative_margin_size * period_size_samples) + + # we generate periods starting from 0 up to n_samples, with and without margins, and period centers + starts = np.arange(0, n_samples, period_size_samples) + periods_for_unit = np.zeros(len(starts), dtype=unit_period_dtype) + periods_for_unit_w_margins = np.zeros(len(starts), dtype=unit_period_dtype) + for i, start in enumerate(starts): + end = min(start + period_size_samples, n_samples) + ext_start = max(0, start - margin_size_samples) + ext_end = min(n_samples, end + margin_size_samples) + periods_for_unit[i]["segment_index"] = segment_index + periods_for_unit[i]["start_sample_index"] = start + periods_for_unit[i]["end_sample_index"] = end + periods_for_unit[i]["unit_index"] = unit_index + periods_for_unit_w_margins[i]["segment_index"] = segment_index + periods_for_unit_w_margins[i]["start_sample_index"] = ext_start + periods_for_unit_w_margins[i]["end_sample_index"] = ext_end + periods_for_unit_w_margins[i]["unit_index"] = unit_index + + all_subperiods.append(periods_for_unit) + all_subperiods_w_margins.append(periods_for_unit_w_margins) + return np.concatenate(all_subperiods), np.concatenate(all_subperiods_w_margins) + + +def merge_overlapping_periods_for_unit(subperiods): + """ + Merges overlapping periods for a single unit and segment. + + Parameters + ---------- + subperiods : np.ndarray + Array of dtype unit_period_dtype containing periods to be merged. + + Returns + ------- + merged_periods : np.ndarray + Array of dtype unit_period_dtype containing merged periods. + """ + segment_indices = np.unique(subperiods["segment_index"]) + assert len(segment_indices) == 1, "Subperiods must belong to the same segment to be merged." + segment_index = segment_indices[0] + unit_indices = np.unique(subperiods["unit_index"]) + assert len(unit_indices) == 1, "Subperiods must belong to the same unit to be merged." 
+
+
+def merge_overlapping_periods_for_unit(subperiods):
+    """
+    Merges overlapping periods for a single unit and segment.
+
+    Parameters
+    ----------
+    subperiods : np.ndarray
+        Array of dtype unit_period_dtype containing periods to be merged.
+
+    Returns
+    -------
+    merged_periods : np.ndarray
+        Array of dtype unit_period_dtype containing merged periods.
+    """
+    segment_indices = np.unique(subperiods["segment_index"])
+    assert len(segment_indices) == 1, "Subperiods must belong to the same segment to be merged."
+    segment_index = segment_indices[0]
+    unit_indices = np.unique(subperiods["unit_index"])
+    assert len(unit_indices) == 1, "Subperiods must belong to the same unit to be merged."
+    unit_index = unit_indices[0]
+
+    # Sort subperiods by start time for interval merging
+    sort_idx = np.argsort(subperiods["start_sample_index"])
+    sorted_subperiods = subperiods[sort_idx]
+
+    # Merge overlapping/adjacent intervals
+    merged_starts = [sorted_subperiods[0]["start_sample_index"]]
+    merged_ends = [sorted_subperiods[0]["end_sample_index"]]
+
+    for i in range(1, len(sorted_subperiods)):
+        current_start = sorted_subperiods[i]["start_sample_index"]
+        current_end = sorted_subperiods[i]["end_sample_index"]
+
+        # Merge if overlapping or contiguous (end >= start)
+        if current_start <= merged_ends[-1]:
+            merged_ends[-1] = max(merged_ends[-1], current_end)
+        else:
+            merged_starts.append(current_start)
+            merged_ends.append(current_end)
+
+    # Construct output array
+    n_periods = len(merged_starts)
+    merged_periods = np.zeros(n_periods, dtype=unit_period_dtype)
+    merged_periods["segment_index"] = segment_index
+    merged_periods["start_sample_index"] = merged_starts
+    merged_periods["end_sample_index"] = merged_ends
+    merged_periods["unit_index"] = unit_index
+
+    return merged_periods
+
+
+def merge_overlapping_periods_across_units_and_segments(periods):
+    """
+    Merges overlapping periods across all units and segments.
+
+    Parameters
+    ----------
+    periods : np.ndarray
+        Array of dtype unit_period_dtype containing periods to be merged.
+
+    Returns
+    -------
+    merged_periods : np.ndarray
+        Array of dtype unit_period_dtype containing merged periods.
+    """
+    segments = np.unique(periods["segment_index"])
+    units = np.unique(periods["unit_index"])
+    merged_periods = []
+    for segment_index in segments:
+        periods_per_segment = periods[periods["segment_index"] == segment_index]
+        for unit_index in units:
+            masked_periods = periods_per_segment[(periods_per_segment["unit_index"] == unit_index)]
+            if len(masked_periods) == 0:
+                continue
+            _merged_periods = merge_overlapping_periods_for_unit(masked_periods)
+            merged_periods.append(_merged_periods)
+    if len(merged_periods) == 0:
+        merged_periods = np.array([], dtype=unit_period_dtype)
+    else:
+        merged_periods = np.concatenate(merged_periods, axis=0)
+
+    return merged_periods
+
+
+register_result_extension(ComputeValidUnitPeriods)
+compute_valid_unit_periods = ComputeValidUnitPeriods.function_factory()
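# Illustrative usage sketch, not part of this patch: computing the extension through the standard
# SortingAnalyzer API. `analyzer` is an assumed, already-existing SortingAnalyzer; the parameter
# values mirror the test setup added below and are not defaults.
analyzer.compute("amplitude_scalings", max_dense_channels=None)  # required input (see the test diff)
analyzer.compute(
    "valid_unit_periods",
    period_mode="relative",
    period_target_num_spikes=200,
    relative_margin_size=0.5,
    min_num_periods_relative=5,
)
valid_periods_ext = analyzer.get_extension("valid_unit_periods")
periods = valid_periods_ext.get_data(outputs="numpy")  # structured array of dtype unit_period_dtype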
+ """ + from spikeinterface.metrics.quality.misc_metrics import ( + amplitude_cutoff, + _compute_nb_violations_numba, + _compute_rp_contamination_one_unit, + ) + + # period is of dtype unit_period_dtype: 0: segment_index, 1: start_sample_index, 2: end_sample_index, 3: unit_index + period_sample = period[0] + segment_index = period_sample["segment_index"] + start_sample_index = period_sample["start_sample_index"] + end_sample_index = period_sample["end_sample_index"] + unit_index = period_sample["unit_index"] + unit_id = sorting.unit_ids[unit_index] + + amplitudes_unit = all_amplitudes_by_unit[segment_index][unit_id] + spiketrain = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + + start_index, end_index = np.searchsorted(spiketrain, [start_sample_index, end_sample_index]) + total_samples_in_period = end_sample_index - start_sample_index + spiketrain_period = spiketrain[start_index:end_index] + amplitudes_period = amplitudes_unit[start_index:end_index] + + # compute fp (rp_violations). See _compute_refrac_period_violations in quality metrics + fs = sorting.sampling_frequency + t_c = int(round(params["censored_period_ms"] * fs * 1e-3)) + t_r = int(round(params["refractory_period_ms"] * fs * 1e-3)) + n_v = _compute_nb_violations_numba(spiketrain_period, t_r) + fp = _compute_rp_contamination_one_unit( + n_v, + len(spiketrain_period), + total_samples_in_period, + t_c, + t_r, + ) + + # compute fn (amplitude_cutoffs) + fn = amplitude_cutoff( + amplitudes_period, + params["num_histogram_bins"], + params["histogram_smoothing_value"], + params["amplitudes_bins_min_ratio"], + ) + return fp, fn + + +def fp_fn_worker_func_wrapper(period): + global worker_ctx + with threadpool_limits(limits=worker_ctx["max_threads_per_worker"]): + fp, fn = fp_fn_worker_func( + period, + worker_ctx["sorting"], + worker_ctx["all_amplitudes_by_unit"], + worker_ctx["params"], + ) + return fp, fn diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 757401d77c..d59193ed8b 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -327,7 +327,6 @@ def _full_update_plot(self, change=None): backend_kwargs = dict(figure=self.figure, axes=None, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) - self._update_plot() def _update_plot(self, change=None): for ax in self.axes.flatten(): @@ -346,9 +345,6 @@ def _update_plot(self, change=None): self.figure.canvas.flush_events() -import numpy as np - - class RasterWidget(BaseRasterWidget): """ Plots spike train rasters. 
diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 42f8b93d74..154d6e4ed3 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -66,6 +66,7 @@ def setUpClass(cls):
             templates=dict(),
             noise_levels=dict(),
             spike_amplitudes=dict(),
+            amplitude_scalings=dict(max_dense_channels=None),  # required by valid unit periods
             unit_locations=dict(),
             spike_locations=dict(),
             quality_metrics=dict(
@@ -74,6 +75,12 @@
             template_metrics=dict(),
             correlograms=dict(),
             template_similarity=dict(),
+            valid_unit_periods=dict(
+                period_mode="relative",
+                period_target_num_spikes=200,
+                relative_margin_size=0.5,
+                min_num_periods_relative=5,
+            ),
         )
 
         job_kwargs = dict(n_jobs=-1)
@@ -687,6 +694,14 @@ def test_plot_motion_info(self):
             if backend not in self.skip_backends:
                 sw.plot_motion_info(motion_info, recording=self.recording, backend=backend)
 
+    def test_plot_valid_unit_periods(self):
+        possible_backends = list(sw.ValidUnitPeriodsWidget.get_possible_backends())
+        for backend in possible_backends:
+            if backend not in self.skip_backends:
+                sw.plot_valid_unit_periods(
+                    self.sorting_analyzer_dense, backend=backend, show_only_units_with_valid_periods=False
+                )
+
 
 if __name__ == "__main__":
     # unittest.main()
diff --git a/src/spikeinterface/widgets/unit_valid_periods.py b/src/spikeinterface/widgets/unit_valid_periods.py
new file mode 100644
index 0000000000..724803b162
--- /dev/null
+++ b/src/spikeinterface/widgets/unit_valid_periods.py
@@ -0,0 +1,236 @@
+from __future__ import annotations
+
+import numpy as np
+
+from spikeinterface.core import SortingAnalyzer
+from .base import BaseWidget, to_attr
+
+
+class ValidUnitPeriodsWidget(BaseWidget):
+    """
+    Plots the valid periods for units, based on the valid_unit_periods extension.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer | None, default: None
+        The sorting analyzer.
+    segment_index : None or int, default: None
+        The segment index. If None, uses first segment.
+    unit_ids : list | None, default: None
+        List of unit ids to plot. If None, all units are plotted.
+    show_only_units_with_valid_periods : bool, default: False
+        If True, only units with valid periods are shown.
+    clip_amplitude_scalings : float | None, default: 5.0
+        Clip amplitude scalings for better visualization. If None, no clipping is applied.
+ """ + + def __init__( + self, + sorting_analyzer: SortingAnalyzer | None = None, + segment_index: int | None = None, + unit_ids: list | None = None, + show_only_units_with_valid_periods: bool = False, + clip_amplitude_scalings: float | None = 5.0, + backend: str | None = None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "valid_unit_periods") + valid_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + if valid_periods_ext.params["method"] == "user_defined": + raise ValueError("UnitValidPeriodsWidget cannot be used with 'user_defined' valid periods.") + + valid_periods = valid_periods_ext.get_data(outputs="numpy") + if show_only_units_with_valid_periods: + valid_unit_ids = sorting_analyzer.unit_ids[np.unique(valid_periods["unit_index"])] + else: + valid_unit_ids = sorting_analyzer.unit_ids + if unit_ids is not None: + valid_unit_ids = [u for u in unit_ids if u in valid_unit_ids] + + if segment_index is None and sorting_analyzer.get_num_segments() == 1: + segment_index = 0 + + data_plot = dict( + sorting_analyzer=sorting_analyzer, + segment_index=segment_index, + unit_ids=valid_unit_ids, + clip_amplitude_scalings=clip_amplitude_scalings, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting_analyzer = dp.sorting_analyzer + num_units = len(dp.unit_ids) + segment_index = dp.segment_index + + if segment_index is None: + nseg = sorting_analyzer.get_num_segments() + if nseg != 1: + raise ValueError("You must provide segment_index=...") + else: + segment_index = 0 + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + if axes.ndim == 1: + axes = axes[:, None] + assert np.asarray(axes).shape == (3, num_units), "Axes shape does not match number of units" + else: + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (2 * num_units, 2 * 3) + backend_kwargs["num_axes"] = num_units * 3 + backend_kwargs["ncols"] = num_units + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting_analyzer = dp.sorting_analyzer + sampling_frequency = sorting_analyzer.sampling_frequency + segment_index = dp.segment_index + good_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + fp_threshold = good_periods_ext.params["fp_threshold"] + fn_threshold = good_periods_ext.params["fn_threshold"] + good_periods = good_periods_ext.get_data(outputs="numpy") + good_periods = good_periods[good_periods["segment_index"] == segment_index] + fps, fns = good_periods_ext.get_fps_and_fns(unit_ids=dp.unit_ids) + period_centers = good_periods_ext.get_period_centers(unit_ids=dp.unit_ids) + + fps_segment = fps[segment_index] + fns_segment = fns[segment_index] + period_centers_segment = period_centers[segment_index] + + amp_scalings_ext = sorting_analyzer.get_extension("amplitude_scalings") + amp_scalings_by_unit = amp_scalings_ext.get_data(outputs="by_unit")[segment_index] + + for ui, unit_id in enumerate(dp.unit_ids): + fp = fps_segment[unit_id] + fn = fns_segment[unit_id] + period_centers = period_centers_segment[unit_id] + unit_index = list(sorting_analyzer.unit_ids).index(unit_id) + + axs = self.axes[:, ui] + # for simplicity we don't use timestamps here + spiketrain = ( + sorting_analyzer.sorting.get_unit_spike_train(unit_id, 
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+
+        dp = to_attr(data_plot)
+        sorting_analyzer = dp.sorting_analyzer
+        num_units = len(dp.unit_ids)
+        segment_index = dp.segment_index
+
+        if segment_index is None:
+            nseg = sorting_analyzer.get_num_segments()
+            if nseg != 1:
+                raise ValueError("You must provide segment_index=...")
+            else:
+                segment_index = 0
+
+        if backend_kwargs["axes"] is not None:
+            axes = backend_kwargs["axes"]
+            if axes.ndim == 1:
+                axes = axes[:, None]
+            assert np.asarray(axes).shape == (3, num_units), "Axes shape does not match number of units"
+        else:
+            if "figsize" not in backend_kwargs:
+                backend_kwargs["figsize"] = (2 * num_units, 2 * 3)
+            backend_kwargs["num_axes"] = num_units * 3
+            backend_kwargs["ncols"] = num_units
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        sampling_frequency = sorting_analyzer.sampling_frequency
+        good_periods_ext = sorting_analyzer.get_extension("valid_unit_periods")
+        fp_threshold = good_periods_ext.params["fp_threshold"]
+        fn_threshold = good_periods_ext.params["fn_threshold"]
+        good_periods = good_periods_ext.get_data(outputs="numpy")
+        good_periods = good_periods[good_periods["segment_index"] == segment_index]
+        fps, fns = good_periods_ext.get_fps_and_fns(unit_ids=dp.unit_ids)
+        period_centers = good_periods_ext.get_period_centers(unit_ids=dp.unit_ids)
+
+        fps_segment = fps[segment_index]
+        fns_segment = fns[segment_index]
+        period_centers_segment = period_centers[segment_index]
+
+        amp_scalings_ext = sorting_analyzer.get_extension("amplitude_scalings")
+        amp_scalings_by_unit = amp_scalings_ext.get_data(outputs="by_unit")[segment_index]
+
+        for ui, unit_id in enumerate(dp.unit_ids):
+            fp = fps_segment[unit_id]
+            fn = fns_segment[unit_id]
+            centers_unit = period_centers_segment[unit_id]
+            unit_index = list(sorting_analyzer.unit_ids).index(unit_id)
+
+            axs = self.axes[:, ui]
+            # for simplicity we don't use timestamps here
+            spiketrain = (
+                sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index) / sampling_frequency
+            )
+            center_bins_s = np.array(centers_unit) / sampling_frequency
+
+            axs[0].plot(center_bins_s, fp, ls="", marker="o", color="r")
+            axs[0].axhline(fp_threshold, color="gray", ls="--")
+            axs[1].plot(center_bins_s, fn, ls="", marker="o")
+            axs[1].axhline(fn_threshold, color="gray", ls="--")
+            amp_scalings_data = amp_scalings_by_unit[unit_id]
+            if dp.clip_amplitude_scalings is not None:
+                amp_scalings_data = np.clip(amp_scalings_data, -dp.clip_amplitude_scalings, dp.clip_amplitude_scalings)
+            axs[2].plot(spiketrain, amp_scalings_data, ls="", marker="o", color="gray", alpha=0.5)
+            axs[2].axhline(1.0, color="k", ls="--")
+            # plot valid periods
+            valid_period_for_units = good_periods[good_periods["unit_index"] == unit_index]
+            for valid_period in valid_period_for_units:
+                start_time = valid_period["start_sample_index"] / sorting_analyzer.sampling_frequency
+                end_time = valid_period["end_sample_index"] / sorting_analyzer.sampling_frequency
+                axs[2].axvspan(start_time, end_time, alpha=0.3, color="g")
+
+            axs[0].set_xlabel("")
+            axs[1].set_xlabel("")
+            axs[2].set_xlabel("Time (s)")
+            axs[0].set_ylabel("FP Rate (RP violations)")
+            axs[1].set_ylabel("FN Rate (Amp. cutoff)")
+            axs[2].set_ylabel("Amplitude Scaling")
+            axs[0].set_title(f"Unit {unit_id}")
+
+            axs[1].sharex(axs[0])
+            axs[2].sharex(axs[0])
+
+        for ax in self.axes.flatten():
+            ax.spines["top"].set_visible(False)
+            ax.spines["right"].set_visible(False)
+
+        self.figure.subplots_adjust(hspace=0.4)
+
+    def plot_ipywidgets(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        import ipywidgets.widgets as widgets
+        from IPython.display import display
+        from .utils_ipywidgets import check_ipywidget_backend, UnitSelector
+
+        check_ipywidget_backend()
+
+        self.next_data_plot = data_plot.copy()
+        analyzer = data_plot["sorting_analyzer"]
+
+        cm = 1 / 2.54
+
+        width_cm = backend_kwargs["width_cm"]
+        height_cm = backend_kwargs["height_cm"]
+
+        ratios = [0.15, 0.85]
+
+        with plt.ioff():
+            output = widgets.Output()
+            with output:
+                # Create figure without axes - let plot_matplotlib create them
+                self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
+                plt.show()
+
+        self.unit_selector = UnitSelector(data_plot["unit_ids"])
+        self.unit_selector.value = list(data_plot["unit_ids"])[:1]
+
+        if analyzer.get_num_segments() > 1:
+            num_segments = analyzer.get_num_segments()
+            segment_value = 0 if data_plot["segment_index"] is None else data_plot["segment_index"]
+            self.segment_selector = widgets.Dropdown(
+                description="segment",
+                options=list(range(num_segments)),
+                value=segment_value,
+                width="100px",
+                height="50px",
+            )
+            self.segment_selector.observe(self._update_plot, names="value", type="change")
+        else:
+            self.segment_selector = None
+
+        self.widget = widgets.AppLayout(
+            center=self.figure.canvas,
+            left_sidebar=self.unit_selector,
+            pane_widths=ratios + [0],
+            footer=self.segment_selector,
+        )
+
+        # a first update
+        self._full_update_plot()
+
+        self.unit_selector.observe(self._update_plot, names=["value"], type="change")
+
+        if backend_kwargs["display"]:
+            display(self.widget)
+
+    def _full_update_plot(self, change=None):
+        self.figure.clear()
+        data_plot = self.next_data_plot
+        data_plot["unit_ids"] = self.unit_selector.value
+        if self.segment_selector is not None:
+            data_plot["segment_index"] = self.segment_selector.value
+        backend_kwargs = dict(figure=self.figure, axes=None, ax=None)
+        self.plot_matplotlib(data_plot, **backend_kwargs)
+
+    def _update_plot(self, change=None):
+        for ax in self.axes.flatten():
+            ax.clear()
+
+        data_plot = self.next_data_plot
+        data_plot["unit_ids"] = self.unit_selector.value
+        if self.segment_selector is not None:
+            data_plot["segment_index"] = self.segment_selector.value
+
+        backend_kwargs = dict(figure=None, axes=self.axes, ax=None)
+        self.plot_matplotlib(data_plot, **backend_kwargs)
+
+        self.figure.canvas.draw()
+        self.figure.canvas.flush_events()
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 6edba67c96..e74ad38053 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -37,6 +37,7 @@
 from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget
 from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary
 from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget
+from .unit_valid_periods import ValidUnitPeriodsWidget
 
 widget_list = [
     AgreementMatrixWidget,
@@ -48,6 +49,7 @@
     CrossCorrelogramsWidget,
     DriftingTemplatesWidget,
     DriftRasterMapWidget,
+    ValidUnitPeriodsWidget,
     ISIDistributionWidget,
     LocationsWidget,
     MotionWidget,
@@ -128,6 +130,7 @@
 plot_crosscorrelograms = CrossCorrelogramsWidget
 plot_drifting_templates = DriftingTemplatesWidget
 plot_drift_raster_map = DriftRasterMapWidget
+plot_valid_unit_periods = ValidUnitPeriodsWidget
 plot_isi_distribution = ISIDistributionWidget
 plot_locations = LocationsWidget
 plot_motion = MotionWidget