diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 42f98b3473..e17731c70e 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -766,10 +766,6 @@ def _compute_and_cache_spike_vector(self) -> None:
             if len(sample_indices) > 0:
                 sample_indices = np.concatenate(sample_indices, dtype="int64")
                 unit_indices = np.concatenate(unit_indices, dtype="int64")
-                order = np.argsort(sample_indices)
-                sample_indices = sample_indices[order]
-                unit_indices = unit_indices[order]
-
             n = sample_indices.size
             segment_slices[segment_index, 0] = seg_pos
             segment_slices[segment_index, 1] = seg_pos + n
@@ -783,7 +779,9 @@ def _compute_and_cache_spike_vector(self) -> None:
             spikes_in_seg["unit_index"] = unit_indices
             spikes_in_seg["segment_index"] = segment_index
             spikes.append(spikes_in_seg)
+
         spikes = np.concatenate(spikes)
+        spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
 
         self._cached_spike_vector = spikes
         self._cached_spike_vector_segment_slices = segment_slices
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py
index eddbe318a6..5d7ca1917a 100644
--- a/src/spikeinterface/core/generate.py
+++ b/src/spikeinterface/core/generate.py
@@ -174,7 +174,7 @@ def generate_sorting(
         spikes.append(spikes_on_borders)
 
     spikes = np.concatenate(spikes)
-    spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]
+    spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
 
     sorting = NumpySorting(spikes, sampling_frequency, unit_ids)
 
diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py
index 3f311f1bdd..57b6397863 100644
--- a/src/spikeinterface/core/testing.py
+++ b/src/spikeinterface/core/testing.py
@@ -112,30 +112,20 @@ def check_sortings_equal(
     max_spike_index = SX1.to_spike_vector()["sample_index"].max()
 
-    # TODO for later use to_spike_vector() to do this without looping
-    for segment_idx in range(SX1.get_num_segments()):
-        # get_unit_ids
-        ids1 = np.sort(np.array(SX1.get_unit_ids()))
-        ids2 = np.sort(np.array(SX2.get_unit_ids()))
-        assert_array_equal(ids1, ids2)
-        for id in ids1:
-            train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx))
-            train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx))
-            assert np.array_equal(train1, train2)
-            train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
-            train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
-            assert np.array_equal(train1, train2)
-            # test that slicing works correctly
-            train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
-            train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
-            assert np.array_equal(train1, train2)
-            train1 = np.sort(
-                SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
-            )
-            train2 = np.sort(
-                SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
-            )
-            assert np.array_equal(train1, train2)
+    s1 = SX1.to_spike_vector()
+    s2 = SX2.to_spike_vector()
+    assert_array_equal(s1, s2)
+
+    for start_frame, end_frame in [
+        (None, None),
+        (30, None),
+        (None, max_spike_index - 30),
+        (30, max_spike_index - 30),
+    ]:
+
+        slice1 = _slice_spikes(s1, start_frame, end_frame)
+        slice2 = _slice_spikes(s2, start_frame, end_frame)
+        assert np.array_equal(slice1, slice2)
 
     if check_annotations:
         check_extractor_annotations_equal(SX1, SX2)
 
@@ -155,3 +145,12 @@ def check_extractor_properties_equal(EX1, EX2) -> None:
 
     for property_name in EX1.get_property_keys():
         assert_array_equal(EX1.get_property(property_name), EX2.get_property(property_name))
+
+
+def _slice_spikes(spikes, start_frame=None, end_frame=None):
+    mask = np.ones(spikes.size, dtype=bool)
+    if start_frame is not None:
+        mask &= spikes["sample_index"] >= start_frame
+    if end_frame is not None:
+        mask &= spikes["sample_index"] <= end_frame
+    return spikes[mask]
diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index c0c9f9b5d7..a13a54bf48 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -295,6 +295,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
         spikes["unit_index"] = spikes_group["unit_index"][:]
         for i, (start, end) in enumerate(segment_slices_list):
             spikes["segment_index"][start:end] = i
+        spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
+
         self._cached_spike_vector = spikes
         for segment_index in range(num_segments):
             soring_segment = SpikeVectorSortingSegment(spikes, segment_index, unit_ids)
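
Note on the sort order: all four files now apply the same np.lexsort key tuple. The snippet below is a small, self-contained numpy sketch (not part of the patch) of why the keys are written as (unit_index, sample_index, segment_index): np.lexsort treats the last key as the primary one, so spikes end up ordered by segment, then sample, then unit, and per-segment [start, stop) offsets computed before the sort stay valid. The dtype here is a simplified stand-in for the real spike-vector dtype.

import numpy as np

# Simplified stand-in for the spike-vector dtype (illustration only).
spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]

# Per-segment blocks concatenated in segment order, as _compute_and_cache_spike_vector does.
spikes = np.array(
    [(5, 1, 0), (5, 0, 0),    # segment 0
     (10, 1, 1), (3, 2, 1)],  # segment 1
    dtype=spike_dtype,
)
segment_slices = np.array([[0, 2], [2, 4]])  # [start, stop) per segment, known before sorting

# np.lexsort sorts by the LAST key first, so this orders by segment, then sample, then unit.
order = np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))
spikes = spikes[order]
# spikes is now [(5, 0, 0), (5, 1, 0), (3, 2, 1), (10, 1, 1)]

# segment_index is the outermost key, so the precomputed [start, stop) slices
# still delimit the same segments after the global sort.
for segment_index, (start, stop) in enumerate(segment_slices):
    assert np.all(spikes["segment_index"][start:stop] == segment_index)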
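Note on the new _slice_spikes helper in testing.py: it replaces the per-unit get_unit_spike_train slicing with a boolean mask on sample_index. The standalone toy run below copies the helper from the patch and shows what each (start_frame, end_frame) window keeps; end_frame is inclusive here (<=), which is fine for the test because both sortings are sliced with the same rule.

import numpy as np

def _slice_spikes(spikes, start_frame=None, end_frame=None):
    # Same helper as in the patch: boolean mask on the sample_index field.
    mask = np.ones(spikes.size, dtype=bool)
    if start_frame is not None:
        mask &= spikes["sample_index"] >= start_frame
    if end_frame is not None:
        mask &= spikes["sample_index"] <= end_frame
    return spikes[mask]

spikes = np.array([(10,), (40,), (90,)], dtype=[("sample_index", "int64")])
print(_slice_spikes(spikes, start_frame=30)["sample_index"])                # [40 90]
print(_slice_spikes(spikes, end_frame=40)["sample_index"])                  # [10 40]
print(_slice_spikes(spikes, start_frame=30, end_frame=40)["sample_index"])  # [40]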
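For reviewers who want to exercise the rewritten check end to end, a hedged usage sketch, assuming generate_sorting (the function touched in generate.py), BaseSorting.to_numpy_sorting, and check_sortings_equal behave as in recent spikeinterface versions; the keyword arguments are illustrative. With the spike vector cached in a canonical order, a round trip through NumpySorting should compare equal regardless of how the backend originally ordered ties.

from spikeinterface.core import generate_sorting
from spikeinterface.core.testing import check_sortings_equal

# Toy sorting; parameters are illustrative and may differ between versions.
sorting = generate_sorting(num_units=3, durations=[5.0], seed=42)

# The round trip rebuilds the spike vector; with the canonical lexsort order,
# the two spike vectors should compare equal element for element.
check_sortings_equal(sorting, sorting.to_numpy_sorting())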