diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..3e2d6583c0 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,10 @@ +from spikeinterface.postprocessing.principal_component import ComputePrincipalComponents +from spikeinterface.preprocessing.scale import scale_to_uV +from spikeinterface.core.analyzer_extension_core import ComputeWaveforms +from pandas.tests.tseries.offsets.test_business_day import offset + +from typing import Optional + from pathlib import Path import warnings @@ -14,6 +21,7 @@ SortingAnalyzer, ) from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.template_tools import get_template_extremum_amplitude from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations from probeinterface import read_prb, Probe @@ -314,7 +322,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer: +def read_kilosort_as_analyzer( + folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None, recording=None +) -> SortingAnalyzer: """ Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and above are supported. The function may work on older versions of Kilosort output, @@ -370,21 +380,25 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse else: AssertionError(f"Cannot read probe layout from folder {phy_path}.") - # to make the initial analyzer, we'll use a fake recording and set it to None later - recording, _ = generate_ground_truth_recording( - probe=probe, - sampling_frequency=sampling_frequency, - durations=[duration], - num_units=1, - seed=1205, - ) + if recording is None: + # to make the initial analyzer, we'll use a fake recording and set it to None later + recording_for_analyzer, _ = generate_ground_truth_recording( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + num_units=1, + seed=1205, + ) + else: + recording_for_analyzer = recording - sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) + # sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) + sparsity = _make_sparsity_from_pcs(recording_for_analyzer, sorting, phy_path) - sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) + sorting_analyzer = create_sorting_analyzer(sorting, recording_for_analyzer, sparse=True, sparsity=sparsity) # first compute random spikes. These do nothing, but are needed for si-gui to run - sorting_analyzer.compute("random_spikes") + sorting_analyzer.compute("random_spikes", method="all") _make_templates( sorting_analyzer, @@ -396,13 +410,208 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse unwhiten=unwhiten, ) _make_locations(sorting_analyzer, phy_path) + _make_amplitudes(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) + # _make_waveforms(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) + # _make_principal_components(sorting_analyzer, phy_path) + + if recording is None: + sorting_analyzer._recording = None - sorting_analyzer._recording = None return sorting_analyzer +def _make_principal_components(sorting_analyzer, kilosort_output_path): + + pcs_extension = ComputePrincipalComponents(sorting_analyzer) + + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") + + print(f"{pcs.shape=}") + + pcs_extension.data = {"pca_projection": pcs} + pcs_extension.params = {"n_components": 6, "mode": "by_channel_local_kilosort", "whiten": True, "dtype": "float32"} + pcs_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["principal_components"] = pcs_extension + + +def _make_waveforms(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_to_uV): + + waveforms_extension = ComputeWaveforms(sorting_analyzer) + + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") + ops_path = kilosort_output_path / "ops.npy" + if ops_path.is_file(): + ops = np.load(ops_path, allow_pickle=True) + wPCA = ops.tolist()["wPCA"] + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) + + spike_vector = sorting_analyzer.sorting.to_spike_vector() + + wht_inv_per_unit = [ + wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) + ] + + order_per_unit_index = {} + for unit_index, _ in enumerate(sorting_analyzer.unit_ids): + channel_order = sorting_analyzer.recording.ids_to_indices( + sorting_analyzer.channel_ids[sorting_analyzer.sparsity.mask[unit_index]] + ) + order_per_unit_index[unit_index] = np.array( + [list(pc_inds[unit_index]).index(channel_ind) for channel_ind in channel_order] + ) + + correct_waveforms = np.empty_like(whitened_waveforms) + for waveform_index, (whitened_waveform, unit_index) in enumerate( + zip(whitened_waveforms, spike_vector["unit_index"]) + ): + unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], whitened_waveform) + correct_waveforms[waveform_index, :, :] = unwhitened_waveform[:, order_per_unit_index[unit_index]] + + waveforms_extension.data = {"waveforms": correct_waveforms * gain_to_uV + offset_to_uV} + waveforms_extension.params = {} + waveforms_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["waveforms"] = waveforms_extension + + +def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_to_uV): + """Constructs approximate `spike_amplitudes` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" + + amplitudes_extension = ComputeSpikeAmplitudes(sorting_analyzer) + + spike_amplitudes_path = kilosort_output_path / "amplitudes.npy" + if spike_amplitudes_path.is_file(): + amps_np = np.load(spike_amplitudes_path) + if amps_np.ndim == 2: + amps_np = np.transpose(amps_np)[0] + else: + return + + # Check that the spike amplitudes vector is the same size as the spike vector + num_spikes = len(sorting_analyzer.sorting.to_spike_vector()) + num_spike_amps = len(amps_np) + if num_spikes != num_spike_amps: + warnings.warn( + "The number of spikes does not match the number of spike amplitudes in `amplitudes.npy`. Skipping spike amplitudes." + ) + return + + spike_vector = sorting_analyzer.sorting.to_spike_vector() + unit_indices = spike_vector["unit_index"] + + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") + ops_path = kilosort_output_path / "ops.npy" + if ops_path.is_file(): + ops = np.load(ops_path, allow_pickle=True) + wPCA = ops.tolist()["wPCA"] + + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") + + # this is the inverse whitening matrix for each channel + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + + # since each unit has a channel sparsity, we need different inverse whitening matrices for each unit + wht_inv_per_unit = [ + wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) + ] + + min_of_waveforms = np.zeros(len(spike_vector)) + for spike_index, (unit_index, pc) in enumerate(zip(unit_indices, pcs, strict=True)): + whitened_waveform = np.einsum("ji,jk->ik", wPCA, pc) + unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], whitened_waveform) + min_of_waveforms[spike_index] = np.min(unwhitened_waveform) + + # WIP code for amplitude scalings + + # whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) + + # for spike_index, (unit_index, waveform) in enumerate(zip(unit_indices, whitened_waveforms)): + # unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], waveform) + # min_of_waveforms[spike_index] = np.min(unwhitened_waveform) + + # for spike_index, (unit_index, waveform) in enumerate(zip(unit_indices, whitened_waveforms)): + # unwhitened_template = unwhitened_templates[unit_index, :, pc_inds[unit_index]] + # whitened_template = whitened_templates[unit_index, :, pc_inds[unit_index]] + # correct_waveform = unwhitened_waveform[:,order_per_unit_index[unit_index]] + + # scaling_factor[spike_index] = np.einsum("ij,ji", correct_waveform, unwhitened_template) / np.einsum( + # "ij,ij", unwhitened_template, unwhitened_template + # ) + + # whitened_templates = np.load(kilosort_output_path / "templates.npy") + + # rescale the amplitudes to physical units, by computing a conversion factor per unit + # based on the ratio between the `absmax`s the unwhitened and whitened templates + # unwhitened_templates = _compute_unwhitened_templates( + # whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=1.0, offset_to_uV=0.0 + # ) + + # order_per_unit_index = {} + # for unit_index, _ in enumerate(sorting_analyzer.unit_ids): + # channel_order = sorting_analyzer.recording.ids_to_indices( + # sorting_analyzer.channel_ids[sorting_analyzer.sparsity.mask[unit_index]] + # ) + # order_per_unit_index[unit_index] = np.array( + # [list(pc_inds[unit_index]).index(channel_ind) for channel_ind in channel_order] + # ) + + # neg_amps = np.max(np.max(np.abs(whitened_waveforms), axis=2), axis=1) + # neg_amps = np.min(np.min(np.einsum("bc,ji,ajk->aik", wh_inv, wPCA, pcs), axis=2), axis=1) + + # if False: + # spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() + # scaling_factors = np.zeros(num_spikes) + # for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): + + # whitened_template = whitened_templates[unit_ind, :, pc_inds[unit_ind]] + # whitened_extremum = np.nanmax(whitened_template) - np.nanmin(whitened_template) + + # unwhitened_template = unwhitened_templates[unit_ind, :, pc_inds[unit_ind]] + # # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) + # # note: we don't `abs` the extrema so that the amps have the expected sign + # # unwhitened_extremum = unwhitened_template[np.unravel_index(unwhitened_extremum_absargmax, unwhitened_template.shape)] + # unwhitened_extremum = np.nanmax(unwhitened_template) - np.nanmin(unwhitened_template) + + # # print(f"{whitened_extremum=}, {unwhitened_extremum=}") + + # conversion_factor = unwhitened_extremum # / whitened_extremum + # # print(f"unit {unit_id} conversion factor: {conversion_factor}") + + # # kilosort always has one segment, so always choose 0 segment index + # inds = spike_indices[0][unit_id] + # scaling_factors[inds] = conversion_factor + + # scaled_amps = scaling_factor * scaling_factors + + amplitudes_extension.data = { + "amplitudes": min_of_waveforms * gain_to_uV + offset_to_uV + } # * scaling_factors * gain_to_uV} + amplitudes_extension.params = {} + amplitudes_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension + + +def _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv): + + # sparsity_mask = sorting_analyzer.sparsity.mask + + # spike_vector = sorting_analyzer.sorting.to_spike_vector + + # unit_sparsity_masks = {unit_id: sorting_analyzer.sparsity.mask[unit_index] for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids)} + # inv_white_per_unit = [wh_inv[unit_sparsity_masks[unit_index],:][:,unit_sparsity_masks[unit_index]].shape for unit_index, _ enumerate(sorting_analyzer.unit_ids) + + whitened_waveforms = np.einsum("ji,ajk->aik", wPCA, pcs) + + return whitened_waveforms + + def _make_locations(sorting_analyzer, kilosort_output_path): - """Constructs a `spike_locations` extension from the amplitudes numpy array + """Constructs a `spike_locations` extension from the locations numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" locations_extension = ComputeSpikeLocations(sorting_analyzer) @@ -452,6 +661,20 @@ def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) +def _make_sparsity_from_pcs(recording, sorting, kilosort_output_path): + """Constructs the `ChannelSparsity` of from kilosort output, by seeing if the + templates output is zero or not on all channels.""" + + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") + unit_ids_to_channel_ids = { + unit_id: recording.channel_ids[pc_inds[unit_index]] for unit_index, unit_id in enumerate(sorting.unit_ids) + } + sparsity = ChannelSparsity.from_unit_id_to_channel_ids( + unit_ids_to_channel_ids, sorting.unit_ids, recording.channel_ids + ) + return sparsity + + def _make_templates( sorting_analyzer, kilosort_output_path, mask, sampling_frequency, gain_to_uV, offset_to_uV, unwhiten=True ): @@ -512,6 +735,7 @@ def _compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset # templates have dimension (num units) x (num samples) x (num channels) # whitening inverse has dimension (num channels) x (num channels) # to undo whitening, we need do matrix multiplication on the channel index + unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates) # then scale to physical units