diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..e6c64dccb3 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -314,7 +314,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer: +def read_kilosort_as_analyzer( + folder_path, recording=None, unwhiten=True, gain_to_uV=None, offset_to_uV=None +) -> SortingAnalyzer: """ Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and above are supported. The function may work on older versions of Kilosort output, @@ -324,6 +326,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse ---------- folder_path : str or Path Path to the output Phy folder (containing the params.py). + recording : BaseRecording + A spikeinterface Recording object which will be attached to the analyzer unwhiten : bool, default: True Unwhiten the templates computed by kilosort. gain_to_uV : float | None, default: None @@ -366,18 +370,32 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse probe = Probe(si_units="um") channel_positions = np.load(phy_path / "channel_positions.npy") probe.set_contacts(channel_positions) - probe.set_device_channel_indices(range(probe.get_contact_count())) + channel_map = np.load(phy_path / "channel_map.npy") + probe.set_device_channel_indices(channel_map) else: AssertionError(f"Cannot read probe layout from folder {phy_path}.") - # to make the initial analyzer, we'll use a fake recording and set it to None later - recording, _ = generate_ground_truth_recording( - probe=probe, - sampling_frequency=sampling_frequency, - durations=[duration], - num_units=1, - seed=1205, - ) + # Check that user-defined recording probe geometry is consistent with phy output + if recording is not None: + for recording_channel_location, probe_contact_position in zip( + recording.get_channel_locations(), probe.contact_positions + ): + if not np.all(recording_channel_location == probe_contact_position): + raise ValueError( + "Recording channel locations from `recording` do not match probe channel locations from `folder_path/probe.prb`." + "Hence there is an inconsistency between probe layout or wiring between the recording and sorting output." + "Please resolve this inconsistency." + ) + + if recording is None: + # to make the initial analyzer, we'll use a fake recording and set it to None later + recording, _ = generate_ground_truth_recording( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + num_units=1, + seed=1205, + ) sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) @@ -397,7 +415,9 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse ) _make_locations(sorting_analyzer, phy_path) - sorting_analyzer._recording = None + if recording is None: + sorting_analyzer._recording = None + return sorting_analyzer @@ -413,14 +433,9 @@ def _make_locations(sorting_analyzer, kilosort_output_path): else: return - # Check that the spike locations vector is the same size as the spike vector + # When recording is given, need to trim spike locations to match spikes in sorting num_spikes = len(sorting_analyzer.sorting.to_spike_vector()) - num_spike_locs = len(locs_np) - if num_spikes != num_spike_locs: - warnings.warn( - "The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations." - ) - return + locs_np = locs_np[:num_spikes] num_dims = len(locs_np[0]) column_names = ["x", "y", "z"][:num_dims]