Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,10 +766,6 @@ def _compute_and_cache_spike_vector(self) -> None:
if len(sample_indices) > 0:
sample_indices = np.concatenate(sample_indices, dtype="int64")
unit_indices = np.concatenate(unit_indices, dtype="int64")
order = np.argsort(sample_indices)
sample_indices = sample_indices[order]
unit_indices = unit_indices[order]

n = sample_indices.size
segment_slices[segment_index, 0] = seg_pos
segment_slices[segment_index, 1] = seg_pos + n
Expand All @@ -783,7 +779,9 @@ def _compute_and_cache_spike_vector(self) -> None:
spikes_in_seg["unit_index"] = unit_indices
spikes_in_seg["segment_index"] = segment_index
spikes.append(spikes_in_seg)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samuelgarcia I think it's better to lexsort the whole array at the end to make sure that units are also sorted within samples


self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = segment_slices
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def generate_sorting(
spikes.append(spikes_on_borders)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we were missing the unit index lexsort in generate


sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

Expand Down
47 changes: 23 additions & 24 deletions src/spikeinterface/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,20 @@ def check_sortings_equal(

max_spike_index = SX1.to_spike_vector()["sample_index"].max()

# TODO for later use to_spike_vector() to do this without looping
for segment_idx in range(SX1.get_num_segments()):
# get_unit_ids
ids1 = np.sort(np.array(SX1.get_unit_ids()))
ids2 = np.sort(np.array(SX2.get_unit_ids()))
assert_array_equal(ids1, ids2)
for id in ids1:
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx))
assert np.array_equal(train1, train2)
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
assert np.array_equal(train1, train2)
# test that slicing works correctly
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
assert np.array_equal(train1, train2)
train1 = np.sort(
SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
)
train2 = np.sort(
SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
)
assert np.array_equal(train1, train2)
s1 = SX1.to_spike_vector()
s2 = SX2.to_spike_vector()
assert_array_equal(s1, s2)

for start_frame, end_frame in [
(None, None),
(30, None),
(None, max_spike_index - 30),
(30, max_spike_index - 30),
]:

slice1 = _slice_spikes(s1, start_frame, end_frame)
slice2 = _slice_spikes(s2, start_frame, end_frame)
assert np.array_equal(slice1, slice2)

if check_annotations:
check_extractor_annotations_equal(SX1, SX2)
Expand All @@ -155,3 +145,12 @@ def check_extractor_properties_equal(EX1, EX2) -> None:

for property_name in EX1.get_property_keys():
assert_array_equal(EX1.get_property(property_name), EX2.get_property(property_name))


def _slice_spikes(spikes, start_frame=None, end_frame=None):
mask = np.ones(spikes.size, dtype=bool)
if start_frame is not None:
mask &= spikes["sample_index"] >= start_frame
if end_frame is not None:
mask &= spikes["sample_index"] <= end_frame
return spikes[mask]
2 changes: 2 additions & 0 deletions src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
spikes["unit_index"] = spikes_group["unit_index"][:]
for i, (start, end) in enumerate(segment_slices_list):
spikes["segment_index"][start:end] = i
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
self._cached_spike_vector = spikes
Comment on lines +298 to +299
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and in zarr


for segment_index in range(num_segments):
soring_segment = SpikeVectorSortingSegment(spikes, segment_index, unit_ids)
Expand Down
Loading