Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions src/spikeinterface/extractors/phykilosortextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer:
def read_kilosort_as_analyzer(
folder_path, recording=None, unwhiten=True, gain_to_uV=None, offset_to_uV=None
) -> SortingAnalyzer:
"""
Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
above are supported. The function may work on older versions of Kilosort output,
Expand All @@ -324,6 +326,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
----------
folder_path : str or Path
Path to the output Phy folder (containing the params.py).
recording : BaseRecording | None, default: None
A SpikeInterface Recording object which, if provided, will be attached to the analyzer. If None, a temporary recording is generated internally and detached before returning.
unwhiten : bool, default: True
Unwhiten the templates computed by kilosort.
gain_to_uV : float | None, default: None
Expand Down Expand Up @@ -366,18 +370,32 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
probe = Probe(si_units="um")
channel_positions = np.load(phy_path / "channel_positions.npy")
probe.set_contacts(channel_positions)
probe.set_device_channel_indices(range(probe.get_contact_count()))
channel_map = np.load(phy_path / "channel_map.npy")
probe.set_device_channel_indices(channel_map)
else:
AssertionError(f"Cannot read probe layout from folder {phy_path}.")

# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
# Check that user-defined recording probe geometry is consistent with phy output
if recording is not None:
for recording_channel_location, probe_contact_position in zip(
recording.get_channel_locations(), probe.contact_positions
):
if not np.all(recording_channel_location == probe_contact_position):
raise ValueError(
"Recording channel locations from `recording` do not match probe channel locations from `folder_path/probe.prb`."
"Hence there is an inconsistency between probe layout or wiring between the recording and sorting output."
"Please resolve this inconsistency."
)

if recording is None:
# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)

sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

Expand All @@ -397,7 +415,9 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
)
_make_locations(sorting_analyzer, phy_path)

sorting_analyzer._recording = None
if recording is None:
sorting_analyzer._recording = None

return sorting_analyzer


Expand All @@ -413,14 +433,9 @@ def _make_locations(sorting_analyzer, kilosort_output_path):
else:
return

# Check that the spike locations vector is the same size as the spike vector
# When recording is given, need to trim spike locations to match spikes in sorting
num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
num_spike_locs = len(locs_np)
if num_spikes != num_spike_locs:
warnings.warn(
"The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
)
return
locs_np = locs_np[:num_spikes]

num_dims = len(locs_np[0])
column_names = ["x", "y", "z"][:num_dims]
Expand Down
Loading