From e139fd3ee6886b7325fef0b64bae8de04f5a683b Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 21 Jan 2026 11:09:01 +0000 Subject: [PATCH 1/9] add amplitudes to ks converter --- .../extractors/phykilosortextractors.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 68b16074fb..b1d0f393e9 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -17,6 +17,7 @@ SortingAnalyzer, ) from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.template_tools import get_template_extremum_amplitude from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations from probeinterface import read_prb, Probe @@ -376,13 +377,57 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) _make_locations(sorting_analyzer, phy_path) + _make_amplitudes(sorting_analyzer, phy_path) sorting_analyzer._recording = None return sorting_analyzer +def _make_amplitudes(sorting_analyzer, kilosort_output_path): + """Constructs approximate `spike_amplitudes` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" + + amplitudes_extension = ComputeSpikeAmplitudes(sorting_analyzer) + + spike_amplitudes_path = kilosort_output_path / "amplitudes.npy" + if spike_amplitudes_path.is_file(): + amps_np = np.load(spike_amplitudes_path) + if amps_np.ndim == 2: + amps_np = np.transpose(amps_np)[0] + else: + return + + # Check that the spike amplitudes vector is the same size as the spike vector + num_spikes = len(sorting_analyzer.sorting.to_spike_vector()) + num_spike_amps = len(amps_np) + if num_spikes != num_spike_amps: + 
warnings.warn( + "The number of spikes does not match the number of spike amplitudes in `amplitudes.npy`. Skipping spike amplitudes." + ) + return + + # rescale the amplitudes to the scale of the templates + peak_to_peak_amps = get_template_extremum_amplitude(sorting_analyzer, peak_sign="both", mode="extremum") + spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() + scaling_factors = np.zeros(num_spikes) + for unit_id in sorting_analyzer.unit_ids: + # kilosort always has one segment, so always choose 0 segment index + inds = spike_indices[0][unit_id] + amps_in_unit = amps_np[inds] + median_amp_in_unit = np.median(amps_in_unit) + scaling_factors[inds] = peak_to_peak_amps[unit_id] / median_amp_in_unit + + scaled_amps = amps_np * scaling_factors + + amplitudes_extension.data = {"amplitudes": scaled_amps} + amplitudes_extension.params = {} + amplitudes_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension + + def _make_locations(sorting_analyzer, kilosort_output_path): - """Constructs a `spike_locations` extension from the amplitudes numpy array + """Constructs a `spike_locations` extension from the locations numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" locations_extension = ComputeSpikeLocations(sorting_analyzer) From 238557a18a332372a347c28ad5f447d93c88b007 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 22 Jan 2026 14:18:20 +0000 Subject: [PATCH 2/9] improve amplitude estimation --- .../extractors/phykilosortextractors.py | 71 +++++++++++++++---- 1 file changed, 56 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index b1d0f393e9..5ce55a6eee 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,4 +1,5 @@ from __future__ import annotations 
+from pandas.tests.tseries.offsets.test_business_day import offset from typing import Optional from pathlib import Path @@ -318,7 +319,7 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: +def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer: """ Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and above are supported. The function may work on older versions of Kilosort output, @@ -330,6 +331,10 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: Path to the output Phy folder (containing the params.py). unwhiten : bool, default: True Unwhiten the templates computed by kilosort. + gain_to_uV : float | None, default: None + The gain to apply to convert traces to uV + offset_to_uV : float | None, default: None + The offset to apply to the traces Returns ------- @@ -337,6 +342,17 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: A SortingAnalyzer object. """ + if gain_to_uV is None: + warnings.warn( + f"No `gain_to_uv` value given. Outputted data will be in dimensionless units. If you know the conversion factor, please pass it to the `read_kilosort_as_analyzer` function." + ) + gain_to_uV = 1.0 + if offset_to_uV is None: + warnings.warn( + f"No `offset_to_uV` value given. Outputted data may not be offset correctly. If you know the offset factor, please pass it to the `read_kilosort_as_analyzer` function." + ) + offset_to_uV = 1.0 + phy_path = Path(folder_path) sorting = read_phy(phy_path) @@ -375,15 +391,17 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: # first compute random spikes. 
These do nothing, but are needed for si-gui to run sorting_analyzer.compute("random_spikes") - _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) + _make_templates( + sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, gain_to_uV, offset_to_uV, unwhiten=unwhiten + ) _make_locations(sorting_analyzer, phy_path) - _make_amplitudes(sorting_analyzer, phy_path) + _make_amplitudes(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) sorting_analyzer._recording = None return sorting_analyzer -def _make_amplitudes(sorting_analyzer, kilosort_output_path): +def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_to_uV): """Constructs approximate `spike_amplitudes` extension from the amplitudes numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" @@ -406,18 +424,33 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path): ) return - # rescale the amplitudes to the scale of the templates - peak_to_peak_amps = get_template_extremum_amplitude(sorting_analyzer, peak_sign="both", mode="extremum") + # rescale the amplitudes to physical units, by computing a conversion factor per unit + # based on the ratio between the `absmax`s the unwhitened and whitened templates + whitened_templates = np.load(kilosort_output_path / "templates.npy") + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + unwhitened_templates = _compute_unwhitened_templates( + whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV + ) + spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() scaling_factors = np.zeros(num_spikes) - for unit_id in sorting_analyzer.unit_ids: + for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): + + whitened_template = whitened_templates[unit_ind, :, :] + whitened_extremum = np.nanmax(np.abs(whitened_template)) + + unwhitened_template = unwhitened_templates[unit_ind, 
:, :] + unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) + # note: we don't `abs` the extrema so that the amps have the expected sign + unwhitened_extremum = unwhitened_template[unwhitened_extremum_absargmax] + + conversion_factor = unwhitened_extremum / whitened_extremum + # kilosort always has one segment, so always choose 0 segment index inds = spike_indices[0][unit_id] - amps_in_unit = amps_np[inds] - median_amp_in_unit = np.median(amps_in_unit) - scaling_factors[inds] = peak_to_peak_amps[unit_id] / median_amp_in_unit + scaling_factors[inds] = conversion_factor - scaled_amps = amps_np * scaling_factors + scaled_amps = amps_np * scaling_factors * gain_to_uV + offset_to_uV amplitudes_extension.data = {"amplitudes": scaled_amps} amplitudes_extension.params = {} @@ -477,7 +510,9 @@ def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) -def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True): +def _make_templates( + sorting_analyzer, kilosort_output_path, mask, sampling_frequency, gain_to_uV, offset_to_uV, unwhiten=True +): """Constructs a `templates` extension from the amplitudes numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" @@ -485,7 +520,11 @@ def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequ whitened_templates = np.load(kilosort_output_path / "templates.npy") wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") - new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv) if unwhiten else whitened_templates + new_templates = ( + _compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset_to_uV) + if unwhiten + else whitened_templates + ) template_extension.data = {"average": new_templates} @@ -493,6 +532,7 @@ def _make_templates(sorting_analyzer, 
kilosort_output_path, mask, sampling_frequ if ops_path.is_file(): ops = np.load(ops_path, allow_pickle=True) + print(f"{ops=}") number_samples_before_template_peak = ops.item(0)["nt0min"] total_template_samples = ops.item(0)["nt"] @@ -524,7 +564,7 @@ def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequ sorting_analyzer.extensions["templates"] = template_extension -def _compute_unwhitened_templates(whitened_templates, wh_inv): +def _compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset_to_uV): """Constructs unwhitened templates from whitened_templates, by applying an inverse whitening matrix.""" @@ -533,4 +573,5 @@ def _compute_unwhitened_templates(whitened_templates, wh_inv): # to undo whitening, we need do matrix multiplication on the channel index unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates) - return unwhitened_templates + # then scale to physical units + return unwhitened_templates * gain_to_uV + offset_to_uV From d40960fed5765d40e8c840344bfeaad730d89d22 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 23 Jan 2026 16:10:51 +0000 Subject: [PATCH 3/9] wip --- .../extractors/phykilosortextractors.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 5ce55a6eee..3cce32b32b 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -424,6 +424,14 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ ) return + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") + ops_path = kilosort_output_path / "ops.npy" + if ops_path.is_file(): + ops = np.load(ops_path, allow_pickle=True) + wPCA = ops.tolist()["wPCA"] + + neg_amps = np.min(np.min(np.einsum("ji,ajk->aik", wPCA, pcs), axis=2), axis=1) + # rescale the amplitudes to 
physical units, by computing a conversion factor per unit # based on the ratio between the `absmax`s the unwhitened and whitened templates whitened_templates = np.load(kilosort_output_path / "templates.npy") @@ -434,25 +442,25 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() scaling_factors = np.zeros(num_spikes) - for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): + # for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - whitened_template = whitened_templates[unit_ind, :, :] - whitened_extremum = np.nanmax(np.abs(whitened_template)) + # whitened_template = whitened_templates[unit_ind, :, :] + # whitened_extremum = np.nanmax(np.abs(whitened_template)) - unwhitened_template = unwhitened_templates[unit_ind, :, :] - unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) - # note: we don't `abs` the extrema so that the amps have the expected sign - unwhitened_extremum = unwhitened_template[unwhitened_extremum_absargmax] + # unwhitened_template = unwhitened_templates[unit_ind, :, :] + # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) + # note: we don't `abs` the extrema so that the amps have the expected sign + # unwhitened_extremum = unwhitened_template[unwhitened_extremum_absargmax] - conversion_factor = unwhitened_extremum / whitened_extremum + # conversion_factor = unwhitened_extremum / whitened_extremum - # kilosort always has one segment, so always choose 0 segment index - inds = spike_indices[0][unit_id] - scaling_factors[inds] = conversion_factor + # kilosort always has one segment, so always choose 0 segment index + # inds = spike_indices[0][unit_id] + # scaling_factors[inds] = conversion_factor - scaled_amps = amps_np * scaling_factors * gain_to_uV + offset_to_uV + # scaled_amps = amps_np * scaling_factors * gain_to_uV + offset_to_uV - amplitudes_extension.data = 
{"amplitudes": scaled_amps} + amplitudes_extension.data = {"amplitudes": neg_amps} amplitudes_extension.params = {} amplitudes_extension.run_info = {"run_completed": True} @@ -532,7 +540,6 @@ def _make_templates( if ops_path.is_file(): ops = np.load(ops_path, allow_pickle=True) - print(f"{ops=}") number_samples_before_template_peak = ops.item(0)["nt0min"] total_template_samples = ops.item(0)["nt"] From 1f954d77f14f92b13f6f78eb1084ce57985c5db1 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Fri, 23 Jan 2026 17:06:15 +0000 Subject: [PATCH 4/9] wip: use pc_features --- .../extractors/phykilosortextractors.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 3cce32b32b..98e452a2c1 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -430,7 +430,9 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ ops = np.load(ops_path, allow_pickle=True) wPCA = ops.tolist()["wPCA"] + # wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") neg_amps = np.min(np.min(np.einsum("ji,ajk->aik", wPCA, pcs), axis=2), axis=1) + # neg_amps = np.min(np.min(np.einsum("bc,ji,ajk->aik", wh_inv, wPCA, pcs), axis=2), axis=1) # rescale the amplitudes to physical units, by computing a conversion factor per unit # based on the ratio between the `absmax`s the unwhitened and whitened templates @@ -440,27 +442,29 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV ) - spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() - scaling_factors = np.zeros(num_spikes) - # for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): + if True: + spike_indices = 
sorting_analyzer.sorting.get_spike_vector_to_indices() + scaling_factors = np.zeros(num_spikes) + for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - # whitened_template = whitened_templates[unit_ind, :, :] - # whitened_extremum = np.nanmax(np.abs(whitened_template)) + whitened_template = whitened_templates[unit_ind, :, :] + whitened_extremum = np.nanmax(np.abs(whitened_template)) - # unwhitened_template = unwhitened_templates[unit_ind, :, :] - # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) - # note: we don't `abs` the extrema so that the amps have the expected sign - # unwhitened_extremum = unwhitened_template[unwhitened_extremum_absargmax] + unwhitened_template = unwhitened_templates[unit_ind, :, :] + # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) + # note: we don't `abs` the extrema so that the amps have the expected sign + # unwhitened_extremum = unwhitened_template[np.unravel_index(unwhitened_extremum_absargmax, unwhitened_template.shape)] + unwhitened_extremum = np.nanmax(np.abs(unwhitened_template)) - # conversion_factor = unwhitened_extremum / whitened_extremum + conversion_factor = unwhitened_extremum / whitened_extremum - # kilosort always has one segment, so always choose 0 segment index - # inds = spike_indices[0][unit_id] - # scaling_factors[inds] = conversion_factor + # kilosort always has one segment, so always choose 0 segment index + inds = spike_indices[0][unit_id] + scaling_factors[inds] = conversion_factor - # scaled_amps = amps_np * scaling_factors * gain_to_uV + offset_to_uV + scaled_amps = neg_amps * scaling_factors - amplitudes_extension.data = {"amplitudes": neg_amps} + amplitudes_extension.data = {"amplitudes": scaled_amps} amplitudes_extension.params = {} amplitudes_extension.run_info = {"run_completed": True} From 2782c69ef1c2ebb27a9926c5017d285d8564397d Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 26 Jan 2026 11:44:55 +0000 Subject: 
[PATCH 5/9] wip: use pc_features to compute amps for ks4 --- .../extractors/phykilosortextractors.py | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 98e452a2c1..e5f7231f0c 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -431,13 +431,31 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ wPCA = ops.tolist()["wPCA"] # wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") - neg_amps = np.min(np.min(np.einsum("ji,ajk->aik", wPCA, pcs), axis=2), axis=1) + + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + + whitened_waveforms = _get_unwhitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) + + spike_vector = sorting_analyzer.sorting.to_spike_vector() + unit_indices = spike_vector["unit_index"] + + print(f"{len(whitened_waveforms)=}") + print(f"{len(unit_indices)=}") + + templates = sorting_analyzer.get_extension("templates").get_data() + sparse_templates = sorting_analyzer.sparsity.sparsify_templates(templates) + + scaling_factor = np.zeros(len(unit_indices)) + for spike_index, (unit_index, waveform) in enumerate(zip(unit_indices, whitened_waveforms)): + template = sparse_templates[unit_index, :, :10] + scaling_factor[spike_index] = np.einsum("ij,ij", waveform, template) / np.einsum("ij,ij", template, template) + + neg_amps = np.max(np.max(np.abs(whitened_waveforms), axis=2), axis=1) # neg_amps = np.min(np.min(np.einsum("bc,ji,ajk->aik", wh_inv, wPCA, pcs), axis=2), axis=1) # rescale the amplitudes to physical units, by computing a conversion factor per unit # based on the ratio between the `absmax`s the unwhitened and whitened templates whitened_templates = np.load(kilosort_output_path / "templates.npy") - wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") 
unwhitened_templates = _compute_unwhitened_templates( whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV ) @@ -462,13 +480,30 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ inds = spike_indices[0][unit_id] scaling_factors[inds] = conversion_factor - scaled_amps = neg_amps * scaling_factors + scaled_amps = scaling_factor * scaling_factors + print("hey") amplitudes_extension.data = {"amplitudes": scaled_amps} amplitudes_extension.params = {} amplitudes_extension.run_info = {"run_completed": True} sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension + template_extension = sorting_analyzer.get_extension("templates") + template_extension.data = {"average": unwhitened_templates} + + +def _get_unwhitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv): + + # sparsity_mask = sorting_analyzer.sparsity.mask + + # spike_vector = sorting_analyzer.sorting.to_spike_vector + + # unit_sparsity_masks = {unit_id: sorting_analyzer.sparsity.mask[unit_index] for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids)} + # inv_white_per_unit = [wh_inv[unit_sparsity_masks[unit_index],:][:,unit_sparsity_masks[unit_index]].shape for unit_index, _ enumerate(sorting_analyzer.unit_ids) + + whitened_waveforms = np.einsum("ji,ajk->aik", wPCA, pcs) + + return whitened_waveforms def _make_locations(sorting_analyzer, kilosort_output_path): @@ -582,6 +617,7 @@ def _compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset # templates have dimension (num units) x (num samples) x (num channels) # whitening inverse has dimension (num units) x (num channels) # to undo whitening, we need do matrix multiplication on the channel index + unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates) # then scale to physical units From 37db69f6c19c40645b95cb4e4779d1edf987e2a8 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 26 Jan 2026 14:38:07 +0000 Subject: 
[PATCH 6/9] wip compute scalings in real space --- .../extractors/phykilosortextractors.py | 79 +++++++++++++------ 1 file changed, 53 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index e5f7231f0c..2361395671 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -384,7 +384,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse seed=1205, ) - sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) + # sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) + sparsity = _make_sparsity_from_pcs(recording, sorting, phy_path) sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) @@ -430,51 +431,63 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ ops = np.load(ops_path, allow_pickle=True) wPCA = ops.tolist()["wPCA"] - # wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") - whitened_waveforms = _get_unwhitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) + whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) + print(f"{whitened_waveforms.shape}") spike_vector = sorting_analyzer.sorting.to_spike_vector() unit_indices = spike_vector["unit_index"] - print(f"{len(whitened_waveforms)=}") - print(f"{len(unit_indices)=}") + whitened_templates = np.load(kilosort_output_path / "templates.npy") + # rescale the amplitudes to physical units, by computing a conversion factor per unit + # based on the ratio between the `absmax`s the unwhitened and whitened templates + unwhitened_templates = _compute_unwhitened_templates( + whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=1.0, 
offset_to_uV=0.0 + ) + + # print(f"{len(whitened_waveforms)=}") + # print(f"{len(unit_indices)=}") - templates = sorting_analyzer.get_extension("templates").get_data() - sparse_templates = sorting_analyzer.sparsity.sparsify_templates(templates) + # templates = sorting_analyzer.get_extension("templates").get_data() + # print(f"{templates.shape=}") + # sparse_templates = sorting_analyzer.sparsity.sparsify_templates(templates) + wht_inv_per_unit = [ + wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) + ] scaling_factor = np.zeros(len(unit_indices)) for spike_index, (unit_index, waveform) in enumerate(zip(unit_indices, whitened_waveforms)): - template = sparse_templates[unit_index, :, :10] - scaling_factor[spike_index] = np.einsum("ij,ij", waveform, template) / np.einsum("ij,ij", template, template) + unwhitened_template = unwhitened_templates[unit_index, :, pc_inds[unit_index]] + # whitened_template = whitened_templates[unit_index, :, pc_inds[unit_index]] + unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], waveform) + scaling_factor[spike_index] = np.einsum("ij,ji", unwhitened_waveform, unwhitened_template) / np.einsum( + "ij,ij", unwhitened_template, unwhitened_template + ) neg_amps = np.max(np.max(np.abs(whitened_waveforms), axis=2), axis=1) # neg_amps = np.min(np.min(np.einsum("bc,ji,ajk->aik", wh_inv, wPCA, pcs), axis=2), axis=1) - # rescale the amplitudes to physical units, by computing a conversion factor per unit - # based on the ratio between the `absmax`s the unwhitened and whitened templates - whitened_templates = np.load(kilosort_output_path / "templates.npy") - unwhitened_templates = _compute_unwhitened_templates( - whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV - ) - if True: spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() scaling_factors = np.zeros(num_spikes) for unit_ind, unit_id in 
enumerate(sorting_analyzer.unit_ids): - whitened_template = whitened_templates[unit_ind, :, :] - whitened_extremum = np.nanmax(np.abs(whitened_template)) + whitened_template = whitened_templates[unit_ind, :, pc_inds[unit_ind]] + whitened_extremum = np.nanmax(whitened_template) - np.nanmin(whitened_template) - unwhitened_template = unwhitened_templates[unit_ind, :, :] + unwhitened_template = unwhitened_templates[unit_ind, :, pc_inds[unit_ind]] # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) # note: we don't `abs` the extrema so that the amps have the expected sign # unwhitened_extremum = unwhitened_template[np.unravel_index(unwhitened_extremum_absargmax, unwhitened_template.shape)] - unwhitened_extremum = np.nanmax(np.abs(unwhitened_template)) + unwhitened_extremum = np.nanmax(unwhitened_template) - np.nanmin(unwhitened_template) + + # print(f"{whitened_extremum=}, {unwhitened_extremum=}") - conversion_factor = unwhitened_extremum / whitened_extremum + conversion_factor = unwhitened_extremum # / whitened_extremum + # print(f"unit {unit_id} conversion factor: {conversion_factor}") # kilosort always has one segment, so always choose 0 segment index inds = spike_indices[0][unit_id] @@ -483,16 +496,14 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ scaled_amps = scaling_factor * scaling_factors print("hey") - amplitudes_extension.data = {"amplitudes": scaled_amps} + amplitudes_extension.data = {"amplitudes": scaling_factor * scaling_factors * gain_to_uV} amplitudes_extension.params = {} amplitudes_extension.run_info = {"run_completed": True} sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension - template_extension = sorting_analyzer.get_extension("templates") - template_extension.data = {"average": unwhitened_templates} -def _get_unwhitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv): +def _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv): # sparsity_mask = 
sorting_analyzer.sparsity.mask @@ -557,6 +568,22 @@ def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) +def _make_sparsity_from_pcs(recording, sorting, kilosort_output_path): + """Constructs the `ChannelSparsity` of from kilosort output, by seeing if the + templates output is zero or not on all channels.""" + + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") + unit_ids_to_channel_ids = { + unit_id: recording.channel_ids[pc_inds[unit_index]] for unit_index, unit_id in enumerate(sorting.unit_ids) + } + sparsity = ChannelSparsity.from_unit_id_to_channel_ids( + unit_ids_to_channel_ids, sorting.unit_ids, recording.channel_ids + ) + + # return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) + return sparsity + + def _make_templates( sorting_analyzer, kilosort_output_path, mask, sampling_frequency, gain_to_uV, offset_to_uV, unwhiten=True ): @@ -615,7 +642,7 @@ def _compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset applying an inverse whitening matrix.""" # templates have dimension (num units) x (num samples) x (num channels) - # whitening inverse has dimension (num units) x (num channels) + # whitening inverse has dimension (num channels) x (num channels) # to undo whitening, we need do matrix multiplication on the channel index unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates) From 651226a64f13147e1712a7c9545a5ca2f9d41fd9 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 28 Jan 2026 15:54:53 +0000 Subject: [PATCH 7/9] wip: add pcs and waveforms --- .../extractors/phykilosortextractors.py | 87 ++++++++++++++++--- 1 file changed, 77 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 2361395671..39c9ad7ba1 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ 
b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,4 +1,7 @@ from __future__ import annotations +from spikeinterface.postprocessing.principal_component import ComputePrincipalComponents +from spikeinterface.preprocessing.scale import scale_to_uV +from spikeinterface.core.analyzer_extension_core import ComputeWaveforms from pandas.tests.tseries.offsets.test_business_day import offset from typing import Optional @@ -390,18 +393,77 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) # first compute random spikes. These do nothing, but are needed for si-gui to run - sorting_analyzer.compute("random_spikes") + sorting_analyzer.compute("random_spikes", method="all") _make_templates( sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, gain_to_uV, offset_to_uV, unwhiten=unwhiten ) _make_locations(sorting_analyzer, phy_path) _make_amplitudes(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) + _make_waveforms(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) + _make_principal_components(sorting_analyzer, phy_path) sorting_analyzer._recording = None return sorting_analyzer +def _make_principal_components(sorting_analyzer, kilosort_output_path): + + pcs_extension = ComputePrincipalComponents(sorting_analyzer) + + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") + + print(f"{pcs.shape=}") + + pcs_extension.data = {"pca_projection": pcs} + pcs_extension.params = {"n_components": 6, "mode": "by_channel_local", "whiten": True, "dtype": "float32"} + pcs_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["principal_components"] = pcs_extension + + +def _make_waveforms(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_to_uV): + + waveforms_extension = ComputeWaveforms(sorting_analyzer) + + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") + ops_path = 
kilosort_output_path / "ops.npy" + if ops_path.is_file(): + ops = np.load(ops_path, allow_pickle=True) + wPCA = ops.tolist()["wPCA"] + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) + + spike_vector = sorting_analyzer.sorting.to_spike_vector() + + wht_inv_per_unit = [ + wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) + ] + + order_per_unit_index = {} + for unit_index, _ in enumerate(sorting_analyzer.unit_ids): + channel_order = sorting_analyzer.recording.ids_to_indices( + sorting_analyzer.channel_ids[sorting_analyzer.sparsity.mask[unit_index]] + ) + order_per_unit_index[unit_index] = np.array( + [list(pc_inds[unit_index]).index(channel_ind) for channel_ind in channel_order] + ) + + correct_waveforms = np.empty_like(whitened_waveforms) + for waveform_index, (whitened_waveform, unit_index) in enumerate( + zip(whitened_waveforms, spike_vector["unit_index"]) + ): + unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], whitened_waveform) + correct_waveforms[waveform_index, :, :] = unwhitened_waveform[:, order_per_unit_index[unit_index]] + + waveforms_extension.data = {"waveforms": correct_waveforms * gain_to_uV + offset_to_uV} + waveforms_extension.params = {} + waveforms_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["waveforms"] = waveforms_extension + + def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_to_uV): """Constructs approximate `spike_amplitudes` extension from the amplitudes numpy array in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" @@ -430,13 +492,9 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ if ops_path.is_file(): ops = np.load(ops_path, allow_pickle=True) wPCA = 
ops.tolist()["wPCA"] - pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") - wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") - whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) - print(f"{whitened_waveforms.shape}") spike_vector = sorting_analyzer.sorting.to_spike_vector() unit_indices = spike_vector["unit_index"] @@ -455,6 +513,14 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ # print(f"{templates.shape=}") # sparse_templates = sorting_analyzer.sparsity.sparsify_templates(templates) + order_per_unit_index = {} + for unit_index, _ in enumerate(sorting_analyzer.unit_ids): + channel_order = sorting_analyzer.recording.ids_to_indices( + sorting_analyzer.channel_ids[sorting_analyzer.sparsity.mask[unit_index]] + ) + order_per_unit_index[unit_index] = np.array( + [list(pc_inds[unit_index]).index(channel_ind) for channel_ind in channel_order] + ) wht_inv_per_unit = [ wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) ] @@ -463,9 +529,11 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ unwhitened_template = unwhitened_templates[unit_index, :, pc_inds[unit_index]] # whitened_template = whitened_templates[unit_index, :, pc_inds[unit_index]] unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], waveform) - scaling_factor[spike_index] = np.einsum("ij,ji", unwhitened_waveform, unwhitened_template) / np.einsum( - "ij,ij", unwhitened_template, unwhitened_template - ) + # correct_waveform = unwhitened_waveform[:,order_per_unit_index[unit_index]] + scaling_factor[spike_index] = np.min(unwhitened_waveform) + # scaling_factor[spike_index] = np.einsum("ij,ji", correct_waveform, unwhitened_template) / np.einsum( + # "ij,ij", unwhitened_template, unwhitened_template + # ) neg_amps = np.max(np.max(np.abs(whitened_waveforms), axis=2), axis=1) # neg_amps = 
np.min(np.min(np.einsum("bc,ji,ajk->aik", wh_inv, wPCA, pcs), axis=2), axis=1) @@ -494,9 +562,8 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ scaling_factors[inds] = conversion_factor scaled_amps = scaling_factor * scaling_factors - print("hey") - amplitudes_extension.data = {"amplitudes": scaling_factor * scaling_factors * gain_to_uV} + amplitudes_extension.data = {"amplitudes": scaling_factor * gain_to_uV} # * scaling_factors * gain_to_uV} amplitudes_extension.params = {} amplitudes_extension.run_info = {"run_completed": True} From ac343bb7a8a1e2c7b1bc8c2ae2eb537476a34c83 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 9 Apr 2026 17:29:03 +0100 Subject: [PATCH 8/9] update spike amps comp --- .../extractors/phykilosortextractors.py | 163 ++++++++++-------- 1 file changed, 91 insertions(+), 72 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 39c9ad7ba1..3e33816fef 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -322,7 +322,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer: +def read_kilosort_as_analyzer( + folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None, recording=None +) -> SortingAnalyzer: """ Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and above are supported. 
The function may work on older versions of Kilosort output, @@ -378,19 +380,22 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse else: AssertionError(f"Cannot read probe layout from folder {phy_path}.") - # to make the initial analyzer, we'll use a fake recording and set it to None later - recording, _ = generate_ground_truth_recording( - probe=probe, - sampling_frequency=sampling_frequency, - durations=[duration], - num_units=1, - seed=1205, - ) + if recording is None: + # to make the initial analyzer, we'll use a fake recording and set it to None later + recording_for_analyzer, _ = generate_ground_truth_recording( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + num_units=1, + seed=1205, + ) + else: + recording_for_analyzer = recording # sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) - sparsity = _make_sparsity_from_pcs(recording, sorting, phy_path) + sparsity = _make_sparsity_from_pcs(recording_for_analyzer, sorting, phy_path) - sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) + sorting_analyzer = create_sorting_analyzer(sorting, recording_for_analyzer, sparse=True, sparsity=sparsity) # first compute random spikes. 
These do nothing, but are needed for si-gui to run sorting_analyzer.compute("random_spikes", method="all") @@ -400,10 +405,12 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse ) _make_locations(sorting_analyzer, phy_path) _make_amplitudes(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) - _make_waveforms(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) - _make_principal_components(sorting_analyzer, phy_path) + # _make_waveforms(sorting_analyzer, phy_path, gain_to_uV, offset_to_uV) + # _make_principal_components(sorting_analyzer, phy_path) + + if recording is None: + sorting_analyzer._recording = None - sorting_analyzer._recording = None return sorting_analyzer @@ -416,7 +423,7 @@ def _make_principal_components(sorting_analyzer, kilosort_output_path): print(f"{pcs.shape=}") pcs_extension.data = {"pca_projection": pcs} - pcs_extension.params = {"n_components": 6, "mode": "by_channel_local", "whiten": True, "dtype": "float32"} + pcs_extension.params = {"n_components": 6, "mode": "by_channel_local_kilosort", "whiten": True, "dtype": "float32"} pcs_extension.run_info = {"run_completed": True} sorting_analyzer.extensions["principal_components"] = pcs_extension @@ -487,83 +494,96 @@ def _make_amplitudes(sorting_analyzer, kilosort_output_path, gain_to_uV, offset_ ) return + spike_vector = sorting_analyzer.sorting.to_spike_vector() + unit_indices = spike_vector["unit_index"] + pcs = np.load(Path(kilosort_output_path) / "pc_features.npy") ops_path = kilosort_output_path / "ops.npy" if ops_path.is_file(): ops = np.load(ops_path, allow_pickle=True) wPCA = ops.tolist()["wPCA"] + pc_inds = np.load(kilosort_output_path / "pc_feature_ind.npy") + + # this is the inverse whitening matrix for each channel wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") - whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) - spike_vector = sorting_analyzer.sorting.to_spike_vector() - unit_indices = 
spike_vector["unit_index"] + # since each unit has a channel sparsity, we need different inverse whitening matrices for each unit + wht_inv_per_unit = [ + wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) + ] - whitened_templates = np.load(kilosort_output_path / "templates.npy") - # rescale the amplitudes to physical units, by computing a conversion factor per unit - # based on the ratio between the `absmax`s the unwhitened and whitened templates - unwhitened_templates = _compute_unwhitened_templates( - whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=1.0, offset_to_uV=0.0 - ) + min_of_waveforms = np.zeros(len(spike_vector)) + for spike_index, (unit_index, pc) in enumerate(zip(unit_indices, pcs, strict=True)): + whitened_waveform = np.einsum("ji,jk->ik", wPCA, pc) + unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], whitened_waveform) + min_of_waveforms[spike_index] = np.min(unwhitened_waveform) - # print(f"{len(whitened_waveforms)=}") - # print(f"{len(unit_indices)=}") + # WIP code for amplitude scalings - # templates = sorting_analyzer.get_extension("templates").get_data() - # print(f"{templates.shape=}") - # sparse_templates = sorting_analyzer.sparsity.sparsify_templates(templates) + # whitened_waveforms = _get_whitened_waveforms(wPCA, pcs, sorting_analyzer, wh_inv) - order_per_unit_index = {} - for unit_index, _ in enumerate(sorting_analyzer.unit_ids): - channel_order = sorting_analyzer.recording.ids_to_indices( - sorting_analyzer.channel_ids[sorting_analyzer.sparsity.mask[unit_index]] - ) - order_per_unit_index[unit_index] = np.array( - [list(pc_inds[unit_index]).index(channel_ind) for channel_ind in channel_order] - ) - wht_inv_per_unit = [ - wh_inv[pc_inds[unit_index], :][:, pc_inds[unit_index]] for unit_index, _ in enumerate(sorting_analyzer.unit_ids) - ] - scaling_factor = np.zeros(len(unit_indices)) - for spike_index, (unit_index, waveform) in 
enumerate(zip(unit_indices, whitened_waveforms)): - unwhitened_template = unwhitened_templates[unit_index, :, pc_inds[unit_index]] - # whitened_template = whitened_templates[unit_index, :, pc_inds[unit_index]] - unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], waveform) - # correct_waveform = unwhitened_waveform[:,order_per_unit_index[unit_index]] - scaling_factor[spike_index] = np.min(unwhitened_waveform) - # scaling_factor[spike_index] = np.einsum("ij,ji", correct_waveform, unwhitened_template) / np.einsum( - # "ij,ij", unwhitened_template, unwhitened_template - # ) - - neg_amps = np.max(np.max(np.abs(whitened_waveforms), axis=2), axis=1) + # for spike_index, (unit_index, waveform) in enumerate(zip(unit_indices, whitened_waveforms)): + # unwhitened_waveform = np.einsum("ij,kj->ki", wht_inv_per_unit[unit_index], waveform) + # min_of_waveforms[spike_index] = np.min(unwhitened_waveform) + + # for spike_index, (unit_index, waveform) in enumerate(zip(unit_indices, whitened_waveforms)): + # unwhitened_template = unwhitened_templates[unit_index, :, pc_inds[unit_index]] + # whitened_template = whitened_templates[unit_index, :, pc_inds[unit_index]] + # correct_waveform = unwhitened_waveform[:,order_per_unit_index[unit_index]] + + # scaling_factor[spike_index] = np.einsum("ij,ji", correct_waveform, unwhitened_template) / np.einsum( + # "ij,ij", unwhitened_template, unwhitened_template + # ) + + # whitened_templates = np.load(kilosort_output_path / "templates.npy") + + # rescale the amplitudes to physical units, by computing a conversion factor per unit + # based on the ratio between the `absmax`s the unwhitened and whitened templates + # unwhitened_templates = _compute_unwhitened_templates( + # whitened_templates=whitened_templates, wh_inv=wh_inv, gain_to_uV=1.0, offset_to_uV=0.0 + # ) + + # order_per_unit_index = {} + # for unit_index, _ in enumerate(sorting_analyzer.unit_ids): + # channel_order = sorting_analyzer.recording.ids_to_indices( + # 
sorting_analyzer.channel_ids[sorting_analyzer.sparsity.mask[unit_index]] + # ) + # order_per_unit_index[unit_index] = np.array( + # [list(pc_inds[unit_index]).index(channel_ind) for channel_ind in channel_order] + # ) + + # neg_amps = np.max(np.max(np.abs(whitened_waveforms), axis=2), axis=1) # neg_amps = np.min(np.min(np.einsum("bc,ji,ajk->aik", wh_inv, wPCA, pcs), axis=2), axis=1) - if True: - spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() - scaling_factors = np.zeros(num_spikes) - for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): + # if False: + # spike_indices = sorting_analyzer.sorting.get_spike_vector_to_indices() + # scaling_factors = np.zeros(num_spikes) + # for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - whitened_template = whitened_templates[unit_ind, :, pc_inds[unit_ind]] - whitened_extremum = np.nanmax(whitened_template) - np.nanmin(whitened_template) + # whitened_template = whitened_templates[unit_ind, :, pc_inds[unit_ind]] + # whitened_extremum = np.nanmax(whitened_template) - np.nanmin(whitened_template) - unwhitened_template = unwhitened_templates[unit_ind, :, pc_inds[unit_ind]] - # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) - # note: we don't `abs` the extrema so that the amps have the expected sign - # unwhitened_extremum = unwhitened_template[np.unravel_index(unwhitened_extremum_absargmax, unwhitened_template.shape)] - unwhitened_extremum = np.nanmax(unwhitened_template) - np.nanmin(unwhitened_template) + # unwhitened_template = unwhitened_templates[unit_ind, :, pc_inds[unit_ind]] + # # unwhitened_extremum_absargmax = np.argmax(np.abs(unwhitened_template), keepdims=True) + # # note: we don't `abs` the extrema so that the amps have the expected sign + # # unwhitened_extremum = unwhitened_template[np.unravel_index(unwhitened_extremum_absargmax, unwhitened_template.shape)] + # unwhitened_extremum = np.nanmax(unwhitened_template) - 
np.nanmin(unwhitened_template) - # print(f"{whitened_extremum=}, {unwhitened_extremum=}") + # # print(f"{whitened_extremum=}, {unwhitened_extremum=}") - conversion_factor = unwhitened_extremum # / whitened_extremum - # print(f"unit {unit_id} conversion factor: {conversion_factor}") + # conversion_factor = unwhitened_extremum # / whitened_extremum + # # print(f"unit {unit_id} conversion factor: {conversion_factor}") - # kilosort always has one segment, so always choose 0 segment index - inds = spike_indices[0][unit_id] - scaling_factors[inds] = conversion_factor + # # kilosort always has one segment, so always choose 0 segment index + # inds = spike_indices[0][unit_id] + # scaling_factors[inds] = conversion_factor - scaled_amps = scaling_factor * scaling_factors + # scaled_amps = scaling_factor * scaling_factors - amplitudes_extension.data = {"amplitudes": scaling_factor * gain_to_uV} # * scaling_factors * gain_to_uV} + amplitudes_extension.data = { + "amplitudes": min_of_waveforms * gain_to_uV + offset_to_uV + } # * scaling_factors * gain_to_uV} amplitudes_extension.params = {} amplitudes_extension.run_info = {"run_completed": True} @@ -647,7 +667,6 @@ def _make_sparsity_from_pcs(recording, sorting, kilosort_output_path): unit_ids_to_channel_ids, sorting.unit_ids, recording.channel_ids ) - # return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) return sparsity From 139b328360cb8f0a356349627bbc61c309c0c823 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 16:43:59 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/extractors/phykilosortextractors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 27231a91af..3e2d6583c0 100644 --- 
a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,4 +1,3 @@ - from spikeinterface.postprocessing.principal_component import ComputePrincipalComponents from spikeinterface.preprocessing.scale import scale_to_uV from spikeinterface.core.analyzer_extension_core import ComputeWaveforms @@ -326,7 +325,6 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove def read_kilosort_as_analyzer( folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None, recording=None ) -> SortingAnalyzer: - """ Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and above are supported. The function may work on older versions of Kilosort output, @@ -662,6 +660,7 @@ def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): mask = np.sum(np.abs(templates), axis=1) != 0 return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) + def _make_sparsity_from_pcs(recording, sorting, kilosort_output_path): """Constructs the `ChannelSparsity` of from kilosort output, by seeing if the templates output is zero or not on all channels."""