diff --git a/doc/how_to/auto_label_units.rst b/doc/how_to/auto_label_units.rst index 96f4cac8f8..3f9c74f37c 100644 --- a/doc/how_to/auto_label_units.rst +++ b/doc/how_to/auto_label_units.rst @@ -183,7 +183,7 @@ file. .. code:: ipython3 - bombcell_labels = sc.bombcell_label_units(sorting_analyzer, thresholds=bombcell_default_thresholds, label_non_somatic=True, split_non_somatic_good_mua=True) + bombcell_labels = sc.bombcell_label_units(sorting_analyzer, thresholds=bombcell_default_thresholds, split_non_somatic=True) .. code:: ipython3 diff --git a/examples/how_to/auto_label_units.py b/examples/how_to/auto_label_units.py index a45a31f613..aa0f047124 100644 --- a/examples/how_to/auto_label_units.py +++ b/examples/how_to/auto_label_units.py @@ -92,7 +92,7 @@ pprint(bombcell_default_thresholds) # %% -bombcell_labels = sc.bombcell_label_units(sorting_analyzer, thresholds=bombcell_default_thresholds, label_non_somatic=True, split_non_somatic_good_mua=True) +bombcell_labels = sc.bombcell_label_units(sorting_analyzer, thresholds=bombcell_default_thresholds, split_non_somatic=True) # %% bombcell_labels["bombcell_label"].value_counts() diff --git a/examples/how_to/bombcell_pipeline_example.py b/examples/how_to/bombcell_pipeline_example.py new file mode 100644 index 0000000000..c33fcacf4d --- /dev/null +++ b/examples/how_to/bombcell_pipeline_example.py @@ -0,0 +1,259 @@ +""" +bombcell quality control example. + +This example shows how to run bombcell quality control on a SortingAnalyzer +with customizable parameters and thresholds. + +Prerequisites: a SortingAnalyzer with waveforms, templates, template_metrics, +and other extensions already computed. See SpikeInterface documentation for +preprocessing, spike sorting, and analyzer creation. 
+""" + +# %% +import spikeinterface.full as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +from pathlib import Path +from pprint import pprint + +# %% Paths - edit these to match your data +analyzer_folder = Path("/path/to/your/sorting_analyzer.zarr") +output_folder = Path("/path/to/your/bombcell_output") + +# %% Load existing SortingAnalyzer +# The analyzer should already have: random_spikes, waveforms, templates, +# noise_levels, unit_locations, spike_locations, template_metrics. +# amplitude_scalings is computed on-demand if needed (for amplitude_cutoff or valid periods). +# See SpikeInterface docs for preprocessing, sorting, and analyzer creation. +analyzer = si.load_sorting_analyzer(analyzer_folder) + +# %% QC parameters +# Get defaults and modify as needed +qc_params = sc.get_default_qc_params() + +# --- Metrics to compute --- +qc_params["compute_amplitude_cutoff"] = True # estimate missing spikes (requires spike_amplitudes) +qc_params["compute_drift"] = True # compute drift metrics (position changes over time) +qc_params["compute_distance_metrics"] = False # isolation distance & L-ratio - slow, not drift-robust + # recommend True for stable/chronic recordings + +# --- bombcell classification options --- +# Note: the refractory-period-violation method (sliding_rp_violation vs +# rp_contamination) is selected below in thresholds["mua"] — that is the ONE +# place to pick the method. To tune its metric-specific params, use +# qc_params["metric_params"] (see bottom of this section). 
+qc_params["split_non_somatic"] = False # if True, non-somatic split into good/mua subcategories + # (to skip non-somatic labeling entirely, clear the + # thresholds["non-somatic"] section below) +qc_params["compute_valid_periods"] = False # if True, compute valid_unit_periods and then compute + # quality metrics restricted to those periods + +# --- Presence ratio --- +qc_params["presence_ratio_bin_duration_s"] = 60 # bin size (s) for checking if unit fires throughout recording + +# --- Drift parameters --- +qc_params["drift_interval_s"] = 60 # time bin (s) for computing position over time +qc_params["drift_min_spikes"] = 100 # min spikes in bin to estimate position (skip if fewer) + +# --- Plotting --- +qc_params["plot_histograms"] = True # save histogram plots of all metrics +qc_params["plot_waveforms"] = True # save waveform plots for each unit +qc_params["plot_upset"] = True # save UpSet plot showing threshold failure combinations + +# --- Custom metric names / params (optional) --- +# To bypass the compute_* flags and specify exactly which metrics to compute, +# set qc_params["metric_names"] to a list. Any SpikeInterface quality metric works. +# qc_params["metric_names"] = [ +# "amplitude_median", "snr", "num_spikes", "presence_ratio", "firing_rate", +# "sliding_rp_violation", "drift", "silhouette", +# ] +# +# To override metric-specific params, set qc_params["metric_params"]: +# qc_params["metric_params"] = { +# "silhouette": {"method": "simplified"}, +# "drift": {"interval_s": 30}, +# } + +# %% Classification thresholds +# Format: {"greater": min_value, "less": max_value} - unit passes if min < value < max +# Use None to disable a bound. Add "abs": True to use absolute value. 
+thresholds = sc.bombcell_get_default_thresholds() + +# --- Noise thresholds (waveform quality) --- +# Units failing ANY of these are labeled "noise" (not neural signals) +thresholds["noise"]["num_positive_peaks"] = {"greater": None, "less": 2} +thresholds["noise"]["num_negative_peaks"] = {"greater": None, "less": 1} +thresholds["noise"]["peak_to_trough_duration"] = {"greater": 0.0001, "less": 0.00115} +thresholds["noise"]["waveform_baseline_flatness"] = {"greater": None, "less": 0.5} +thresholds["noise"]["peak_after_to_trough_ratio"] = {"greater": None, "less": 0.8} +thresholds["noise"]["exp_decay"] = {"greater": 0.01, "less": 0.1} + +# --- MUA thresholds (spike quality) --- +# Units failing ANY of these (that passed noise) are labeled "mua" (multi-unit activity) +thresholds["mua"]["amplitude_median"] = {"greater": 30, "less": None, "abs": True} +thresholds["mua"]["snr"] = {"greater": 5, "less": None} +thresholds["mua"]["amplitude_cutoff"] = {"greater": None, "less": 0.2} +thresholds["mua"]["num_spikes"] = {"greater": 300, "less": None} +thresholds["mua"]["sliding_rp_violation"] = {"greater": None, "less": 0.1} +thresholds["mua"]["presence_ratio"] = {"greater": 0.7, "less": None} +thresholds["mua"]["drift_ptp"] = {"greater": None, "less": 100} + +# Optional distance metrics (only used if compute_distance_metrics=True) +# thresholds["mua"]["isolation_distance"] = {"greater": 20, "less": None} +# thresholds["mua"]["l_ratio"] = {"greater": None, "less": 0.3} + +# --- Non-somatic thresholds (waveform shape) --- +# Detects axonal/dendritic units based on waveform features +thresholds["non-somatic"]["peak_before_to_trough_ratio"] = {"greater": None, "less": 3} +thresholds["non-somatic"]["peak_before_width"] = {"greater": 0.00015, "less": None} +thresholds["non-somatic"]["trough_width"] = {"greater": 0.0002, "less": None} +thresholds["non-somatic"]["peak_before_to_peak_after_ratio"] = {"greater": None, "less": 3} +thresholds["non-somatic"]["main_peak_to_trough_ratio"] = 
{"greater": None, "less": 0.8} + +# %% Adding custom metrics +# You can add ANY metric from the SortingAnalyzer's quality_metrics or +# template_metrics DataFrame to ANY threshold section (noise, mua, non-somatic). +# +# Metrics in "noise" section: unit fails if ANY threshold is violated -> labeled "noise" +# Metrics in "mua" section: unit fails if ANY threshold is violated -> labeled "mua" +# Metrics in "non-somatic" section: OR'd with built-in waveform shape checks +# Metrics that haven't been computed are automatically skipped (with a warning) +# +# Examples: +# thresholds["mua"]["firing_rate"] = {"greater": 0.1, "less": None} +# thresholds["noise"]["half_width"] = {"greater": 0.05e-3, "less": 0.6e-3} +# thresholds["non-somatic"]["velocity_above"] = {"greater": 2.0, "less": None} +# +# To DISABLE an existing threshold: +# thresholds["mua"]["drift_ptp"] = {"greater": None, "less": None} + +pprint(thresholds) + +# %% Run bombcell QC +# This computes quality metrics and classifies units as good/mua/noise/non-somatic. +# Both `params` and `thresholds` also accept a path to a JSON file: +# e.g. params="qc_params.json", thresholds="thresholds.json". +# After each run, the thresholds and bombcell-specific config are saved to +# output_folder as thresholds.json and bombcell_config.json for reproducibility. 
+# +# Rerun flags force recomputation of specific extensions (all default False): +# rerun_quality_metrics - quality_metrics +# rerun_pca - principal_components (for distance metrics) +# rerun_amplitude_scalings - amplitude_scalings (prerequisite for amplitude_cutoff and valid periods) +labels, metrics, figures = sc.run_bombcell_qc( + sorting_analyzer=analyzer, + output_folder=output_folder, + params=qc_params, + thresholds=thresholds, + rerun_quality_metrics=False, + n_jobs=-1, +) + +# %% Results +print(f"\nResults saved to: {output_folder}") +print(f"\nLabel distribution:\n{labels['bombcell_label'].value_counts()}") + +good_units = labels[labels["bombcell_label"] == "good"].index.tolist() +mua_units = labels[labels["bombcell_label"] == "mua"].index.tolist() +noise_units = labels[labels["bombcell_label"] == "noise"].index.tolist() +non_soma_units = labels[labels["bombcell_label"] == "non_soma"].index.tolist() + +print(f"\nGood units ({len(good_units)}): {good_units[:10]}...") +print(f"MUA units ({len(mua_units)}): {mua_units[:10]}...") + +# %% Visualize: template peaks and troughs +_ = sw.plot_template_peak_trough( + analyzer, + unit_ids=analyzer.unit_ids[:8], + n_channels_around=2, + unit_labels=labels["bombcell_label"], + figsize=(20, 12), +) + +# %% Using valid time periods +# Valid periods identify chunks of time where each unit has stable amplitude +# and low refractory period violations. Quality metrics computed on those +# chunks are more representative for units that drop out or drift during the +# recording. This is useful when recordings have unstable periods (e.g., drift, +# probe movement, or electrode noise). +# +# There are two ways to enable this, depending on how much control you want. +# +# --- Option A: let the pipeline handle it (simple case, defaults) --- +# qc_params["compute_valid_periods"] = True +# labels, metrics, figures = sc.run_bombcell_qc(analyzer, params=qc_params, ...) +# The pipeline will: +# 1. 
compute valid_unit_periods with default fp/fn thresholds +# 2. compute quality_metrics with use_valid_periods=True +# 3. hand the resulting metrics to bombcell for labeling +# +# --- Option B: compute valid periods yourself first (recommended for tuning) --- +# This is the explicit route: you decide the fp/fn thresholds, period mode, etc., +# and bombcell just reads the resulting metrics. Recommended because it makes +# the "what was the fp threshold?" question unambiguous instead of hidden. +# +# analyzer.compute("amplitude_scalings") # prerequisite +# analyzer.compute( +# "valid_unit_periods", +# fp_threshold=0.1, # should line up with your bombcell RPV threshold +# fn_threshold=0.1, # should line up with your bombcell amplitude_cutoff threshold +# period_duration_s_absolute=30.0, +# period_target_num_spikes=300, +# period_mode="absolute", +# minimum_valid_period_duration=180, +# ) +# qc_params["compute_valid_periods"] = True # tell the pipeline to use valid_periods +# # (the pipeline sees the extension already exists and reuses it as-is; it will +# # warn you if its fp/fn don't match your bombcell thresholds) +# labels, metrics, figures = sc.run_bombcell_qc(analyzer, params=qc_params, ...) +# +# After running, the per-unit valid periods live on the analyzer. Access with: +# valid_periods = analyzer.get_extension("valid_unit_periods").get_data() + +# %% Using bombcell_label_units directly (without the pipeline) +# If you want more control, you can call bombcell_label_units directly. +# This skips quality metric computation, plotting, and saving — you handle those yourself. + +# Basic usage: just pass the analyzer and thresholds +labels_direct = sc.bombcell_label_units( + sorting_analyzer=analyzer, + thresholds=thresholds, +) + +# With external metrics (e.g. 
from a CSV or custom computation): +# import pandas as pd +# my_metrics = pd.read_csv("my_metrics.csv", index_col=0) +# labels_direct = sc.bombcell_label_units( +# external_metrics=my_metrics, +# thresholds=thresholds, +# ) + +# Note: bombcell_label_units is a pure labeler — it does not compute or +# recompute any extension. If you want valid-periods-aware labels, compute +# valid_unit_periods and quality_metrics(use_valid_periods=True) yourself +# before calling bombcell_label_units (see "Option B" above). + +# Thresholds can also be loaded from a JSON file: +# labels_direct = sc.bombcell_label_units( +# sorting_analyzer=analyzer, +# thresholds="my_thresholds.json", +# ) + +# %% Parameter tuning by recording type +# +# Chronic recordings: +# - Distance metrics work well (stable recordings = reliable isolation_distance/l_ratio) +# - Set compute_distance_metrics = True +# - Drift is typically minimal, so drift metrics are not very informative +# +# Acute recordings: +# - Distance metrics unreliable (drift artificially lowers isolation_distance/l_ratio) +# - Keep compute_distance_metrics = False (default) +# - Keep drift threshold strict +# +# Cerebellum: +# - Complex spikes may trigger noise detection; relax num_positive_peaks +# +# Striatum: +# - MSNs: lower spike count and presence ratio thresholds diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 16bdddd870..552b999bc4 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -30,3 +30,7 @@ bombcell_label_units, save_bombcell_results, ) +from .bombcell_pipeline import ( + get_default_qc_params, + run_bombcell_qc, +) diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 9e0b17632e..61ffa34199 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -9,6 +9,7 @@ """ import operator +import 
warnings from pathlib import Path import json import numpy as np @@ -30,9 +31,11 @@ "snr", "amplitude_cutoff", "num_spikes", - "rp_contamination", + "sliding_rp_violation", "presence_ratio", "drift_ptp", + "isolation_distance", + "l_ratio", ] DEFAULT_NON_SOMATIC_METRICS = [ @@ -43,13 +46,28 @@ "main_peak_to_trough_ratio", ] +# Metrics belonging to the built-in non-somatic groups. +# The compound logic is: (width_group AND ratio_group) OR main_peak_group +# Any metric in the "non-somatic" threshold section that is NOT listed here +# is treated as a standalone condition OR'd into the final result. +_NON_SOMATIC_WIDTH_GROUP = {"peak_before_width", "trough_width"} +_NON_SOMATIC_RATIO_GROUP = {"peak_before_to_trough_ratio", "peak_before_to_peak_after_ratio"} +_NON_SOMATIC_MAIN_PEAK_GROUP = {"main_peak_to_trough_ratio"} +_NON_SOMATIC_BUILTIN_METRICS = _NON_SOMATIC_WIDTH_GROUP | _NON_SOMATIC_RATIO_GROUP | _NON_SOMATIC_MAIN_PEAK_GROUP + def bombcell_get_default_thresholds() -> dict: """ bombcell - Returns default thresholds for unit labeling. Each metric has 'greater' and 'less' values. Use None to disable a threshold (e.g. to ignore a metric completely - or to only have a greater or a less threshold) + or to only have a greater or a less threshold). + + Refractory-period violations: the "mua" section must contain exactly ONE of + ``"sliding_rp_violation"`` or ``"rp_contamination"``. That single key picks + both the RPV method that gets computed and the threshold applied to it — + the pipeline reads this entry to decide which RPV metric to compute. + The default here uses ``"sliding_rp_violation"``. 
""" # bombcell return { @@ -66,9 +84,11 @@ def bombcell_get_default_thresholds() -> dict: "snr": {"greater": 5, "less": None}, "amplitude_cutoff": {"greater": None, "less": 0.2}, "num_spikes": {"greater": 300, "less": None}, - "rp_contamination": {"greater": None, "less": 0.1}, + "sliding_rp_violation": {"greater": None, "less": 0.1}, "presence_ratio": {"greater": 0.7, "less": None}, "drift_ptp": {"greater": None, "less": 100}, # um + "isolation_distance": {"greater": 20, "less": None}, + "l_ratio": {"greater": None, "less": 0.3}, }, "non-somatic": { "peak_before_to_trough_ratio": {"greater": None, "less": 3}, @@ -83,8 +103,7 @@ def bombcell_get_default_thresholds() -> dict: def bombcell_label_units( sorting_analyzer=None, thresholds: dict | str | Path | None = None, - label_non_somatic: bool = True, - split_non_somatic_good_mua: bool = False, + split_non_somatic: bool = False, external_metrics: "pd.DataFrame | list[pd.DataFrame] | None" = None, ) -> "pd.DataFrame": """ @@ -102,7 +121,9 @@ def bombcell_label_units( Units that pass all noise and MUA thresholds are labeled as "good". 4. NON-SOMATIC: Among units that are not "noise", those that meet non-somatic criteria based on waveform shape are - labeled as "non_soma". + labeled as "non_soma". Non-somatic labeling is active whenever ``thresholds["non-somatic"]`` + is non-empty; skip it entirely by leaving that section empty or omitting it. + These non-somatic criteria include: - Narrow peak and trough widths (using "peak_before_width" and "trough_width" metrics) @@ -111,7 +132,11 @@ def bombcell_label_units( - Large main peak to trough ratio (using "main_peak_to_trough_ratio" metric) If units have a narrow peak and a large ratio OR a large main peak to trough ratio, - they are labeled as non-somatic. If `split_non_somatic_good_mua` is True, non-somatic units are further split + they are labeled as non-somatic. 
Custom metrics can also be added to the "non-somatic" + threshold section — any metric not part of the built-in groups (width, ratio, main_peak) + is treated as a standalone condition OR'd into the non-somatic detection. + + If `split_non_somatic` is True, non-somatic units are further split into "non_soma_good" and "non_soma_mua", otherwise they are all labeled as "non_soma". Parameters @@ -123,13 +148,29 @@ def bombcell_label_units( Threshold dict or JSON file, including a three sections ("noise", "mua", "non-somatic") of {"metric": {"greater": val, "less": val}}. If None, default Bombcell thresholds are used. - label_non_somatic : bool, default: True - If True, detect non-somatic (dendritic, axonal) units. - split_non_somatic_good_mua : bool, default: False - If True, split non-somatic into "non_soma_good" and "non_soma_mua". + + Refractory-period violation: include exactly ONE of ``"sliding_rp_violation"`` + or ``"rp_contamination"`` under ``thresholds["mua"]``. This single entry + selects BOTH which RPV metric is used AND its threshold — there is no + separate knob for the method choice. ``run_bombcell_qc`` also reads this + entry to decide which RPV metric to compute. + split_non_somatic : bool, default: False + If True, split non-somatic units into "non_soma_good" and "non_soma_mua" + based on whether they pass MUA thresholds. If False, all non-somatic + units are labeled "non_soma". Has no effect if the non-somatic section + of ``thresholds`` is empty (non-somatic labeling is off in that case). external_metrics: "pd.DataFrame | list[pd.DataFrame]" | None = None External metrics DataFrame(s) (index = unit_ids) to use instead of those from SortingAnalyzer. + Notes + ----- + This function is a pure labeler: it reads metrics off the analyzer (or + ``external_metrics``) and applies the thresholds. It does NOT compute or + recompute any extension. 
If you want quality metrics restricted to valid + unit periods, compute ``valid_unit_periods`` and ``quality_metrics`` with + ``use_valid_periods=True`` yourself before calling this function (or use + ``run_bombcell_qc`` with ``params["compute_valid_periods"]=True``). + Returns ------- labels : pd.DataFrame @@ -143,6 +184,16 @@ def bombcell_label_units( """ import pandas as pd + if thresholds is None: + thresholds_dict = bombcell_get_default_thresholds() + elif isinstance(thresholds, (str, Path)): + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("thresholds must be a dict, a JSON file path, or None") + if sorting_analyzer is not None: combined_metrics = sorting_analyzer.get_metrics_extension_data() if combined_metrics.empty: @@ -161,15 +212,22 @@ def bombcell_label_units( else: combined_metrics = external_metrics - if thresholds is None: - thresholds_dict = bombcell_get_default_thresholds() - elif isinstance(thresholds, (str, Path)): - with open(thresholds, "r") as f: - thresholds_dict = json.load(f) - elif isinstance(thresholds, dict): - thresholds_dict = thresholds - else: - raise ValueError("thresholds must be a dict, a JSON file path, or None") + # Filter out threshold metrics that are not present in the metrics DataFrame. + # This allows optional metrics (e.g. isolation_distance, l_ratio) to be included + # in the default thresholds without crashing when they haven't been computed. + available_columns = set(combined_metrics.columns) + for section in ("noise", "mua", "non-somatic"): + if section not in thresholds_dict: + continue + missing = [m for m in thresholds_dict[section] if m not in available_columns] + if missing: + warnings.warn( + f"Bombcell thresholds reference metrics not found in the metrics DataFrame " + f"(section '{section}'): {missing}. These will be skipped. " + f"Compute them first if you want them included in the labeling." 
+            )
+            kept = {m: t for m, t in thresholds_dict[section].items() if m in available_columns}
+            thresholds_dict = {**thresholds_dict, section: kept}  # rebind a copy; never mutate the caller's dict
 
     n_units = len(combined_metrics)
 
@@ -200,8 +258,10 @@ def bombcell_label_units(
         )
         unit_labels.loc[unit_labels.index[non_noise_indices], "label"] = mua_labels["label"].values
 
-    if label_non_somatic:
-        non_somatic_thresholds = thresholds_dict.get("non-somatic", {})
+    # Non-somatic labeling is driven by whether the user supplied any thresholds
+    # in the non-somatic section — no separate on/off flag.
+    non_somatic_thresholds = thresholds_dict.get("non-somatic", {})
+    if len(non_somatic_thresholds) > 0:
         width_thresholds = {
             m: non_somatic_thresholds[m] for m in ["peak_before_width", "trough_width"] if m in non_somatic_thresholds
         }
@@ -263,7 +323,24 @@ def bombcell_label_units(
         # (ratio AND width) OR standalone main_peak_to_trough
         is_non_somatic = (ratio_conditions & width_conditions) | large_main_peak
 
-        if split_non_somatic_good_mua:
+        # Standalone custom metrics: any metric in non-somatic thresholds that is not
+        # part of the built-in groups is OR'd in as its own independent condition.
+        standalone_metrics = {
+            m: non_somatic_thresholds[m] for m in non_somatic_thresholds if m not in _NON_SOMATIC_BUILTIN_METRICS
+        }
+        if len(standalone_metrics) > 0:
+            standalone_labels = threshold_metrics_label_units(
+                metrics=combined_metrics,
+                thresholds=standalone_metrics,
+                pass_label="pass",
+                fail_label="fail",
+                operator="or",
+                nan_policy="ignore",
+            )
+            # OR the standalone "fail" labels into the non-somatic mask (once; only when such metrics exist).
+            is_non_somatic = is_non_somatic | (standalone_labels["label"] == "fail")
+
+        if split_non_somatic:
             good_mask = unit_labels["label"] == "good"
             mua_mask = unit_labels["label"] == "mua"
             unit_labels.loc[good_mask & is_non_somatic, "label"] = "non_soma_good"
diff --git a/src/spikeinterface/curation/bombcell_pipeline.py b/src/spikeinterface/curation/bombcell_pipeline.py
new file mode 100644
index 0000000000..f75491d3e1
--- /dev/null
+++ b/src/spikeinterface/curation/bombcell_pipeline.py
@@ -0,0 +1,505 @@
+"""
+BombCell pipeline functions for quality control.
+
+This module provides wrapper functions for running the BombCell quality
+control pipeline on spike-sorted data.
+
+Functions
+---------
+get_default_qc_params
+    Get default parameters for quality metrics and BombCell labeling.
+run_bombcell_qc
+    Compute quality metrics, run BombCell labeling, and generate plots.
+
+See Also
+--------
+bombcell_get_default_thresholds : Get default classification thresholds.
+bombcell_label_units : Core labeling function.
+"""
+
+from __future__ import annotations
+import warnings
+from pathlib import Path
+
+# The two refractory-period-violation metrics BombCell supports. The user picks
+# one by including it as a key under thresholds["mua"] — that single choice
+# determines which metric is computed AND thresholded. No other knob selects it.
+_RPV_METRICS = ("sliding_rp_violation", "rp_contamination")
+
+
+def _resolve_rpv_metric(thresholds: dict) -> str:
+    """Return the single RPV metric name the user selected in thresholds["mua"].
+
+    Raises if the user specified both (ambiguous) or neither (no RPV check).
+    """
+    mua = thresholds.get("mua", {})
+    picked = [m for m in _RPV_METRICS if m in mua]
+    if len(picked) == 0:
+        raise ValueError(
+            f"thresholds['mua'] must include exactly one of {_RPV_METRICS} "
+            "to select the refractory-period-violation method."
+        )
+    if len(picked) > 1:
+        raise ValueError(
+            f"thresholds['mua'] contains multiple RPV metrics ({picked}). " "Pick exactly one: remove the other entry."
+        )
+    return picked[0]
+
+
+def _warn_if_valid_periods_mismatch(sorting_analyzer, thresholds: dict) -> None:
+    """Warn if a pre-existing valid_unit_periods extension was built with fp/fn
+    thresholds that disagree with the bombcell RPV / amplitude_cutoff thresholds.
+
+    The point is to surface the mismatch (not fix it) — the user picked those
+    valid-periods params on purpose, but if their bombcell thresholds have
+    since drifted, the labels below won't match the periods above.
+    """
+    ext = sorting_analyzer.get_extension("valid_unit_periods")
+    if ext is None:
+        return
+    vp_params = ext.params or {}
+    mua = thresholds.get("mua", {})
+
+    # Match the RPV metric that bombcell will threshold (first _RPV_METRICS key present in "mua")
+    rpv_bc = None
+    for name in _RPV_METRICS:
+        if name in mua:
+            rpv_bc = mua[name].get("less", None)
+            break
+    fp_vp = vp_params.get("fp_threshold", None)
+    if rpv_bc is not None and fp_vp is not None and fp_vp != rpv_bc:
+        warnings.warn(
+            f"valid_unit_periods was computed with fp_threshold={fp_vp}, but bombcell "
+            f"RPV threshold is {rpv_bc}. The valid periods were chosen with a different "
+            "contamination criterion than the one bombcell will use for labeling. "
+            "Recompute valid_unit_periods if you want them to line up."
+ ) + + ac_bc = mua.get("amplitude_cutoff", {}).get("less", None) + fn_vp = vp_params.get("fn_threshold", None) + if ac_bc is not None and fn_vp is not None and fn_vp != ac_bc: + warnings.warn( + f"valid_unit_periods was computed with fn_threshold={fn_vp}, but bombcell " + f"amplitude_cutoff threshold is {ac_bc}. The valid periods were chosen with " + "a different missing-spikes criterion than bombcell will use for labeling. " + "Recompute valid_unit_periods if you want them to line up." + ) + + +def get_default_qc_params(): + """ + Get default parameters for quality metrics and BombCell labeling. + + Returns a dictionary that can be modified and passed to run_bombcell_qc(). + + Returns + ------- + dict + Dictionary with the following keys: + + **Quality Metrics Selection** + + compute_amplitude_cutoff : bool, default: False + Whether to compute amplitude_cutoff metric (estimated percentage of + missing spikes). Requires spike_amplitudes extension which is slow + to compute for large recordings. When enabled, spike_amplitudes will + be computed automatically if not already present. + + compute_distance_metrics : bool, default: False + Whether to compute isolation_distance and l_ratio metrics. + These require PCA computation and are slow for large datasets. + Useful for chronic recordings where cluster stability matters. + Not recommended for acute recordings with expected drift. + + compute_drift : bool, default: True + Whether to compute drift metrics (drift_ptp, drift_std, drift_mad). + Measures how much units move over the recording. Important for + acute recordings. drift_ptp (peak-to-peak drift in um) is used + by BombCell MUA thresholds. + + **BombCell Labeling Options** + + Note: the refractory-period-violation method (``sliding_rp_violation`` + or ``rp_contamination``) and its threshold are selected in a single + place — the thresholds dict (``thresholds["mua"]``). 
Whichever RPV + key the user puts there determines which metric is computed AND + thresholded. Extra metric-specific parameters can be passed via + ``params["metric_params"]`` (e.g. ``{"sliding_rp_violation": + {"exclude_ref_period_below_ms": 0.5}}``). + + split_non_somatic : bool, default: False + If True, split non-somatic units into "non_soma_good" and + "non_soma_mua" based on whether they pass MUA thresholds. + If False, all non-somatic units are labeled "non_soma". + To skip non-somatic labeling entirely, empty or omit the + ``"non-somatic"`` section of ``thresholds``. + + compute_valid_periods : bool, default: False + If True, identify valid time periods per unit (where the unit + has stable amplitude and low refractory violations) and compute + quality metrics only on those periods. Useful for recordings + with unstable periods. Requires amplitude_scalings extension. + If the ``valid_unit_periods`` extension is already present on the + analyzer, it is reused as-is (no recompute); if its ``fp_threshold`` + / ``fn_threshold`` differ from the bombcell RPV / amplitude_cutoff + thresholds, a warning is emitted. To customize valid-periods + parameters, compute the extension upstream with your own settings. + + **Presence Ratio Parameters** + + presence_ratio_bin_duration_s : float, default: 60 + Bin duration in seconds for computing presence ratio. + Presence ratio = fraction of bins containing at least one spike. + 60s bins are standard; shorter bins are stricter. + + **Drift Parameters** + + drift_interval_s : float, default: 60 + Interval in seconds for computing drift. Unit positions are + estimated in each interval and drift is the movement across intervals. + + drift_min_spikes : int, default: 100 + Minimum spikes required per interval to estimate position. + Intervals with fewer spikes are skipped. + + **Plotting Options** + + plot_histograms : bool, default: True + Generate histograms of all metrics with threshold lines. 
def run_bombcell_qc(
    sorting_analyzer,
    output_folder: str | Path | None = "bombcell",
    params: dict | str | Path | None = None,
    thresholds: dict | str | Path | None = None,
    rerun_quality_metrics: bool = False,
    rerun_pca: bool = False,
    rerun_amplitude_scalings: bool = False,
    n_jobs: int = -1,
    progress_bar: bool = True,
):
    """
    Compute quality metrics and run BombCell unit labeling.

    This function computes quality metrics on the SortingAnalyzer, runs the
    BombCell labeling algorithm to classify units as good/mua/noise/non_soma,
    generates diagnostic plots, and saves results.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        Analyzer with template_metrics extension computed.
    output_folder : str, Path, or None, default: "bombcell"
        Folder to save results (CSV files and plots). Set to None to skip saving.
        Created if it doesn't exist.
    params : dict, str, Path, or None, default: None
        QC parameters from get_default_qc_params(), or a path to a JSON file
        (read as UTF-8) containing such a dict. If None, uses defaults.
        The supplied values are layered over get_default_qc_params(), so a
        partial dict only needs the keys that differ from the defaults.

        To override the default metric list built from the compute_* flags, set
        params["metric_names"] to an explicit list of metric names. To override
        the default metric params, set params["metric_params"] to a dict mapping
        metric name -> param dict (each entry replaces the pipeline's param dict
        for that metric). Any metric you add this way must correspond to
        a valid SpikeInterface quality metric.
    thresholds : dict, str, Path, or None, default: None
        BombCell classification thresholds from bombcell_get_default_thresholds(),
        or a path to a JSON file (read as UTF-8) containing such a dict.
        If None, uses defaults. Unlike ``params``, thresholds are NOT merged with
        the defaults, because the presence or absence of a key is meaningful.
        Structure:

        - "noise": Thresholds for waveform quality. Failing ANY -> "noise".
        - "mua": Thresholds for spike quality. Failing ANY -> "mua".
        - "non-somatic": Thresholds for waveform shape. Determines non-somatic units.

        Each threshold is {"greater": value, "less": value}. Use None to disable.
        See bombcell_get_default_thresholds() docstring for all thresholds.

        Refractory-period violations: ``thresholds["mua"]`` must contain exactly
        ONE of ``"sliding_rp_violation"`` or ``"rp_contamination"``. That single
        entry is the ONE place where the RPV method is selected — it determines
        both which RPV metric is computed AND the threshold applied to it.

    rerun_quality_metrics : bool, default: False
        Force recomputation of quality metrics even if they exist.
    rerun_pca : bool, default: False
        Force recomputation of PCA (only relevant if compute_distance_metrics=True).
    rerun_amplitude_scalings : bool, default: False
        Force recomputation of amplitude_scalings (used as a prerequisite for
        amplitude_cutoff and valid periods).
    n_jobs : int, default: -1
        Number of parallel jobs.
    progress_bar : bool, default: True
        Show progress bars.

    Returns
    -------
    labels : pd.DataFrame
        DataFrame with unit_ids as index and "bombcell_label" column.
        Possible labels: "good", "mua", "noise", "non_soma"
        (or "non_soma_good"/"non_soma_mua" if split_non_somatic=True).
    metrics : pd.DataFrame
        Combined DataFrame of all quality metrics and template metrics.
        Index is unit_ids, columns are metric names.
    figures : dict
        Dictionary of matplotlib figures (only the plots that were enabled):
        - "histograms": Metric histograms with threshold lines.
        - "waveforms": Waveform overlays grouped by label.
        - "upset": List of UpSet plot figures (one per label type).

    Saved Files (in output_folder)
    ------------------------------
    - labeling_results_wide.csv: One row per unit with all metrics and label.
    - labeling_results_narrow.csv: One row per unit-metric with pass/fail status.
    - thresholds.json: Thresholds used for classification (reproducibility).
    - bombcell_config.json: bombcell-specific options (split_non_somatic,
      compute_valid_periods).
      Note: quality metric params are stored on the analyzer via the quality_metrics extension.
    - metric_histograms.png: Histogram of each metric with threshold lines.
    - waveforms_by_label.png: Waveform overlays for each label category.
    - upset_plot_*.png: UpSet plots showing metric failure combinations.

    Note on valid periods
    ---------------------
    When ``params["compute_valid_periods"]=True``, the valid time periods per unit
    are stored by the ``valid_unit_periods`` extension on the analyzer itself.
    Access them via ``sorting_analyzer.get_extension("valid_unit_periods").get_data()``.
    If you want to customize the ``valid_unit_periods`` parameters (``fp_threshold``,
    ``fn_threshold``, ``period_mode``, etc.), compute the extension upstream and
    the pipeline will reuse it. A warning is raised if its fp/fn differ from the
    bombcell RPV / amplitude_cutoff thresholds.

    Examples
    --------
    Basic usage with defaults:

    >>> labels, metrics, figs = run_bombcell_qc(analyzer)

    With custom parameters and thresholds:

    >>> params = get_default_qc_params()
    >>> params["compute_distance_metrics"] = True  # For chronic recordings
    >>>
    >>> thresholds = bombcell_get_default_thresholds()
    >>> thresholds["mua"]["sliding_rp_violation"]["less"] = 0.05  # Stricter RP violations
    >>> thresholds["mua"]["num_spikes"]["greater"] = 100  # Lower spike threshold
    >>>
    >>> labels, metrics, figs = run_bombcell_qc(
    ...     analyzer,
    ...     output_folder="qc_results",
    ...     params=params,
    ...     thresholds=thresholds,
    ... )

    Get good units for downstream analysis:

    >>> good_units = labels[labels["bombcell_label"] == "good"].index.tolist()
    >>> mua_units = labels[labels["bombcell_label"] == "mua"].index.tolist()
    """
    import json

    from .bombcell_curation import (
        bombcell_get_default_thresholds,
        bombcell_label_units,
        save_bombcell_results,
    )

    # Resolve params (dict, JSON path, or None). JSON is UTF-8 by specification,
    # so do not rely on the platform's default text encoding.
    if params is None:
        params = {}
    elif isinstance(params, (str, Path)):
        with open(params, "r", encoding="utf-8") as f:
            params = json.load(f)
    # Layer the user's values over the defaults so a partial dict / JSON file
    # does not fail with a KeyError on the first option it leaves out. This also
    # builds a fresh dict, so the caller's object is never aliased.
    params = {**get_default_qc_params(), **params}

    # Resolve thresholds (dict, JSON path, or None). Deliberately no merge with the
    # defaults: which keys are present selects the RPV method and can disable sections.
    if thresholds is None:
        thresholds = bombcell_get_default_thresholds()
    elif isinstance(thresholds, (str, Path)):
        with open(thresholds, "r", encoding="utf-8") as f:
            thresholds = json.load(f)

    job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=progress_bar)

    if output_folder is not None:
        output_folder = Path(output_folder)
        output_folder.mkdir(parents=True, exist_ok=True)

    # The RPV method is selected by whichever key the user puts in thresholds["mua"]
    # (either "sliding_rp_violation" or "rp_contamination"). This is the ONE place
    # the choice lives — no separate pipeline-level parameter.
    rpv_metric = _resolve_rpv_metric(thresholds)

    # Per-metric parameters derived from the flat pipeline params.
    qm_params = {
        "presence_ratio": {"bin_duration_s": params["presence_ratio_bin_duration_s"]},
        "drift": {
            "interval_s": params["drift_interval_s"],
            "min_spikes_per_interval": params["drift_min_spikes"],
        },
    }
    # User-provided metric_params override the defaults built above
    # (whole-entry replacement per metric, not a key-by-key merge).
    user_metric_params = params.get("metric_params")
    if user_metric_params is not None:
        for metric, metric_kwargs in user_metric_params.items():
            qm_params[metric] = metric_kwargs

    # Build metric names (user can override via params["metric_names"]).
    # An explicit list is used verbatim; the compute_* flags and the RPV choice
    # only shape the default list.
    user_metric_names = params.get("metric_names")
    if user_metric_names is not None:
        metric_names = list(user_metric_names)
    else:
        metric_names = ["amplitude_median", "snr", "num_spikes", "presence_ratio", "firing_rate"]
        if params["compute_amplitude_cutoff"]:
            metric_names.append("amplitude_cutoff")
        metric_names.append(rpv_metric)
        if params["compute_drift"]:
            metric_names.append("drift")
        if params["compute_distance_metrics"]:
            metric_names.append("mahalanobis")

    compute_valid_periods = params["compute_valid_periods"]

    # Ensure prerequisite extensions are computed for whichever metrics are requested
    # amplitude_cutoff uses amplitude_scalings (not spike_amplitudes) in this pipeline
    needs_amplitude_scalings = "amplitude_cutoff" in metric_names or compute_valid_periods
    if needs_amplitude_scalings and (
        rerun_amplitude_scalings or not sorting_analyzer.has_extension("amplitude_scalings")
    ):
        sorting_analyzer.compute("amplitude_scalings", **job_kwargs)

    if "mahalanobis" in metric_names and (
        rerun_pca or not sorting_analyzer.has_extension("principal_components")
    ):
        sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)

    # Valid unit periods (pipeline-level, since it affects which QMs bombcell sees).
    # - If the extension already exists, reuse it as-is, but warn if its fp/fn thresholds
    #   don't line up with the bombcell RPV / amplitude_cutoff thresholds.
    # - If it doesn't exist, compute it using defaults (users who want custom fp/fn
    #   or period params should compute it upstream themselves).
    if compute_valid_periods:
        if sorting_analyzer.has_extension("valid_unit_periods"):
            _warn_if_valid_periods_mismatch(sorting_analyzer, thresholds)
        else:
            sorting_analyzer.compute("valid_unit_periods", **job_kwargs)

    # Compute quality metrics.
    # NOTE(review): an existing quality_metrics extension is reused untouched unless
    # rerun_quality_metrics=True; nothing here checks that it contains the metrics
    # requested above — confirm callers pass rerun_quality_metrics=True when needed.
    if rerun_quality_metrics or not sorting_analyzer.has_extension("quality_metrics"):
        sorting_analyzer.compute(
            "quality_metrics",
            metric_names=metric_names,
            metric_params=qm_params,
            use_valid_periods=compute_valid_periods,
            **job_kwargs,
        )

    # Run BombCell (pure labeler — no extension mutation, no valid-periods magic)
    labels = bombcell_label_units(
        sorting_analyzer=sorting_analyzer,
        thresholds=thresholds,
        split_non_somatic=params["split_non_somatic"],
    )

    metrics = sorting_analyzer.get_metrics_extension_data()

    # Plots (the widgets module is only imported when at least one plot is requested)
    figures = {}
    if params["plot_histograms"] or params["plot_waveforms"] or params["plot_upset"]:
        import spikeinterface.widgets as sw

        if params["plot_histograms"]:
            w = sw.plot_metric_histograms(sorting_analyzer, thresholds, figsize=params["figsize_histograms"])
            figures["histograms"] = w.figure

        if params["plot_waveforms"]:
            w = sw.plot_unit_labels(sorting_analyzer, labels["bombcell_label"], ylims=params["waveform_ylims"])
            figures["waveforms"] = w.figure

        if params["plot_upset"]:
            w = sw.plot_bombcell_labels_upset(
                sorting_analyzer,
                unit_labels=labels["bombcell_label"],
                thresholds=thresholds,
                unit_labels_to_plot=["noise", "mua"],
            )
            figures["upset"] = w.figures

    # Save
    if output_folder is not None:
        save_bombcell_results(
            metrics=metrics,
            unit_label=labels["bombcell_label"].values,
            thresholds=thresholds,
            folder=output_folder,
        )
        # Save thresholds and bombcell-specific config so the run is reproducible
        # (quality metric params are stored on the analyzer itself via the extension)
        with open(output_folder / "thresholds.json", "w", encoding="utf-8") as f:
            json.dump(thresholds, f, indent=2)
        bombcell_config = {
            "split_non_somatic": params["split_non_somatic"],
            "compute_valid_periods": compute_valid_periods,
        }
        with open(output_folder / "bombcell_config.json", "w", encoding="utf-8") as f:
            json.dump(bombcell_config, f, indent=2)
        # Note: valid periods are stored by the valid_unit_periods extension on the analyzer itself.
        # Access them via: analyzer.get_extension("valid_unit_periods").get_data()
        if "histograms" in figures:
            figures["histograms"].savefig(output_folder / "metric_histograms.png", dpi=150, bbox_inches="tight")
        if "waveforms" in figures:
            figures["waveforms"].savefig(output_folder / "waveforms_by_label.png", dpi=150, bbox_inches="tight")
        if "upset" in figures:
            for i, fig in enumerate(figures["upset"]):
                fig.savefig(output_folder / f"upset_plot_{i}.png", dpi=150, bbox_inches="tight")

    print(f"Labeled {len(labels)} units:")
    print(labels["bombcell_label"].value_counts().to_string())

    return labels, metrics, figures