diff --git a/pyproject.toml b/pyproject.toml index 6280068..2fdb158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,8 @@ classifiers = [ ] dependencies = [ - "spikeinterface[full]>=0.104.0", + # "spikeinterface[full]>=0.104.0", + "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git", "markdown" ] diff --git a/spikeinterface_gui/backend_panel.py b/spikeinterface_gui/backend_panel.py index 19c8a2b..9abeaa6 100644 --- a/spikeinterface_gui/backend_panel.py +++ b/spikeinterface_gui/backend_panel.py @@ -134,8 +134,7 @@ def on_active_view_updated(self, param): view._panel_view_is_active = False def on_unit_color_changed(self, param): - if not self._active: - return + # In this case we send it also if the view is not active, because we want to update colors anyways for view in self.controller.views: if param.obj.view == view: continue diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 6f3f60c..f05a8ce 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -382,8 +382,8 @@ def __init__( curation_data = json.load(f) elif self.analyzer.format == "zarr": - import zarr - zarr_root = zarr.open(self.analyzer.folder, mode='r') + from spikeinterface.core.zarrextractors import super_zarr_open + zarr_root = super_zarr_open(self.analyzer.folder, mode='r') if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys(): curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"] @@ -548,6 +548,26 @@ def get_information_txt(self): return txt + def get_divergent_unit_colors(self, colormap="tab10"): + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap + + unit_locations = self.analyzer.get_extension("unit_locations").get_data() + cmap = plt.get_cmap(colormap) + if not isinstance(cmap, ListedColormap): + raise ValueError(f"Colormap {colormap} is not a qualitative colormap") + num_entries = len(cmap.colors) + # lexsort by x and y + sorted_inds = np.lexsort((unit_locations[:, 0], unit_locations[:, 1])) + # now assign colors with sequentially to sorted units + colors = {} + for i, unit_ind in enumerate(sorted_inds): + unit_id = self.unit_ids[unit_ind] + # Assign cmap color *and* alpha value to the colors dict + colors[unit_id] = cmap.colors[i % num_entries] + (1,) + return colors + + def refresh_colors(self): if self.backend == "qt": self._cached_qcolors = {} @@ -555,15 +575,13 @@ def refresh_colors(self): pass if self.main_settings['color_mode'] == 'color_by_unit': - self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + self.colors = self.get_divergent_unit_colors(colormap="tab10") elif self.main_settings['color_mode'] == 'color_only_visible': - unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + unit_colors = self.get_divergent_unit_colors(colormap="tab10") self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids} for unit_id in self.get_visible_unit_ids(): self.colors[unit_id] = unit_colors[unit_id] - elif self.main_settings['color_mode'] == 'color_by_visibility': + elif self.main_settings['color_mode'] == 'color_by_visibility': self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids} import matplotlib.pyplot as plt cmap = plt.colormaps['tab10'] diff --git a/spikeinterface_gui/correlogramview.py b/spikeinterface_gui/correlogramview.py index 9ca6fa6..1f9d1ee 100644 --- a/spikeinterface_gui/correlogramview.py +++ b/spikeinterface_gui/correlogramview.py @@ -34,6 +34,11 @@ def _compute(self): # clear cache self.figure_cache = {} + def on_unit_color_changed(self): + # clear cache + self.figure_cache = {} + self.refresh() + ## Qt ## def _qt_make_layout(self): @@ -73,6 +78,15 @@ def _qt_refresh(self): unit_id2 = visible_unit_ids[c] if (unit_id1, unit_id2) in self.figure_cache: plot = self.figure_cache[(unit_id1, unit_id2)] + if self.controller.main_settings["color_mode"] == 'color_by_visibility': + # Update color in cached figure + if r == c: + unit_id = visible_unit_ids[r] + color = colors[unit_id] + for item in plot.items: + if hasattr(item, 'setBrush') and hasattr(item, 'setPen'): + item.setBrush(color) + item.setPen(color) else: # create new plot i = unit_ids.index(visible_unit_ids[r]) @@ -145,6 +159,16 @@ def _panel_refresh(self): if (unit1, unit2) in self.figure_cache: fig = self.figure_cache[(unit1, unit2)] + # for the color_by_visibility + if self.controller.main_settings["color_mode"] == 'color_by_visibility': + # Update color in cached figure + if r == c: + unit_id = visible_unit_ids[r] + color = colors[unit_id] + for renderer in fig.renderers: + if hasattr(renderer, 'glyph') and hasattr(renderer.glyph, 'fill_color'): + renderer.glyph.fill_color = color + renderer.glyph.line_color = color else: # create new figure i = unit_ids.index(unit1) diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index 79a2638..e3af479 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -40,7 +40,6 @@ def on_max_visible_units_changed(self): self.notify_unit_visibility_changed() def on_change_color_mode(self): - self.controller.main_settings['color_mode'] = self.main_settings['color_mode'] self.controller.refresh_colors() self.notify_unit_color_changed() diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 66712ea..66c98d1 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -48,6 +48,7 @@ class MergeView(ViewBase): def __init__(self, controller=None, parent=None, backend="qt"): ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) self.include_deleted = False + self.exclude_noise = True def compute_potential_merges(self): preset = self.preset @@ -82,14 +83,17 @@ def compute_potential_merges(self): f"({len(potential_merges)} after filtering deleted units).") def get_potential_merges(self): - # return the potential merges, considering the include deleted option - unit_ids = list(self.controller.unit_ids) + # return the potential merges, considering the include deleted and exclude noise options proposed_merge_unit_groups = [] for group_ids in self.proposed_merge_unit_groups_all: - if not self.include_deleted and self.controller.curation: - deleted_unit_ids = self.controller.curation_data["removed"] - if any(unit_id in deleted_unit_ids for unit_id in group_ids): - continue + if self.controller.curation: + if not self.include_deleted: + deleted_unit_ids = self.controller.curation_data["removed"] + if any(unit_id in deleted_unit_ids for unit_id in group_ids): + continue + if self.exclude_noise: + if any(self.controller.get_unit_label(unit_id, "quality") == "noise" for unit_id in group_ids): + continue proposed_merge_unit_groups.append(group_ids) return proposed_merge_unit_groups @@ -158,7 +162,6 @@ def accept_group_merge(self, group_ids): ) return self.notify_manual_curation_updated() - self.refresh() ### QT def _qt_get_selected_group_ids(self): @@ -205,6 +208,10 @@ def _qt_on_include_deleted_change(self): self.include_deleted = self.include_deleted_checkbox.isChecked() self.refresh() + def _qt_on_exclude_noise_change(self): + self.exclude_noise = self.exclude_noise_checkbox.isChecked() + self.refresh() + def _qt_make_layout(self): from .myqt import QT import pyqtgraph as pg @@ -246,10 +253,19 @@ def _qt_make_layout(self): row_layout.addWidget(but) if self.controller.curation: + checkbox_layout = QT.QVBoxLayout() + self.include_deleted_checkbox = QT.QCheckBox("Include deleted units") self.include_deleted_checkbox.setChecked(False) self.include_deleted_checkbox.stateChanged.connect(self._qt_on_include_deleted_change) - row_layout.addWidget(self.include_deleted_checkbox) + checkbox_layout.addWidget(self.include_deleted_checkbox) + + self.exclude_noise_checkbox = QT.QCheckBox("Exclude noise units") + self.exclude_noise_checkbox.setChecked(True) + self.exclude_noise_checkbox.stateChanged.connect(self._qt_on_exclude_noise_change) + checkbox_layout.addWidget(self.exclude_noise_checkbox) + + row_layout.addLayout(checkbox_layout) self.layout.addLayout(row_layout) @@ -353,11 +369,22 @@ def _panel_make_layout(self): name=f"{preset.capitalize()} parameters") self.preset = list(self.preset_params.keys())[0] - # shortcuts + # group the preset selector and its parameters into a collapsible accordion, + # so it can be hidden after merges are computed + self.preset_settings_column = pn.Column( + self.preset_selector, + self.preset_params_selectors[self.preset], + sizing_mode="stretch_width", + ) + self.preset_accordion = pn.Accordion( + ("Preset & parameters", self.preset_settings_column), + active=[0], + sizing_mode="stretch_width", + ) + + # shortcuts (row navigation is handled by SelectableTabulator; only accept here) shortcuts = [ KeyboardShortcut(name="accept", key="a", ctrlKey=True), - KeyboardShortcut(name="next", key="ArrowDown", ctrlKey=False), - KeyboardShortcut(name="previous", key="ArrowUp", ctrlKey=False), ] shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) shortcuts_component.on_msg(self._panel_handle_shortcut) @@ -374,13 +401,15 @@ def _panel_make_layout(self): if self.controller.curation: self.include_deleted = pn.widgets.Checkbox(name="Include deleted units", value=False) self.include_deleted.param.watch(self._panel_include_deleted_change, "value") - calculate_list.append(self.include_deleted) + + self.exclude_noise_widget = pn.widgets.Checkbox(name="Exclude noise units", value=True) + self.exclude_noise_widget.param.watch(self._panel_exclude_noise_change, "value") + + calculate_list.append(pn.Column(self.include_deleted, self.exclude_noise_widget)) calculate_row = pn.Row(*calculate_list, sizing_mode="stretch_width") self.layout = pn.Column( - # add params - self.preset_selector, - self.preset_params_selectors[self.preset], + self.preset_accordion, calculate_row, self.table_area, shortcuts_component, @@ -394,24 +423,33 @@ def _panel_refresh(self): import pandas as pd import panel as pn import matplotlib.colors as mcolors - from .utils_panel import unit_formatter + from .utils_panel import unit_formatter, SelectableTabulator pn.extension("tabulator") # Create table labels, rows = self.get_table_data() + + if not rows: + self.table = None + self.table_area.update("No merges computed yet.") + return + # set unmutable data data = {label: [] for label in labels} for row in rows: for label in labels: if label.startswith("unit_id"): unit_id = row[label] - data[label].append({"id": unit_id, "color": mcolors.to_hex(self.controller.get_unit_color(unit_id))}) + n = self.controller.num_spikes[unit_id] + data[label].append({"id": unit_id, "color": mcolors.to_hex(self.controller.get_unit_color(unit_id)), "n": n}) else: data[label].append(row[label]) df = pd.DataFrame(data=data) formatters = {label: unit_formatter for label in labels if label.startswith("unit_id")} - self.table = pn.widgets.Tabulator( + skip_sort_columns = [label for label in labels if label.startswith("unit_id")] + skip_sort_columns.append("group_ids") + self.table = SelectableTabulator( df, formatters=formatters, height=400, @@ -420,40 +458,40 @@ def _panel_refresh(self): hidden_columns=["group_ids"], disabled=True, selectable=1, - sortable=False + sortable=True, + skip_sort_columns=skip_sort_columns, + # SelectableTabulator functions + parent_view=self, + conditional_shortcut=self.is_view_active, + on_selection_changed=self._panel_on_selection_changed, ) - - # Add click handler with double click detection - self.table.on_click(self._panel_on_click) self.table_area.update(self.table) def _panel_compute_merges(self, event): self._compute_merges() + # collapse the preset accordion once merges have been computed + if self.table is not None: + self.preset_accordion.active = [] def _panel_on_preset_change(self, event): self.preset = event.new - if self.is_warning_active(): - layout_index = 2 - else: - layout_index = 1 - self.layout[layout_index] = self.preset_params_selectors[self.preset] - - def _panel_on_click(self, event): - import panel as pn + self.preset_settings_column[1] = self.preset_params_selectors[self.preset] - # set unit visibility - row = event.row - - def _do_update(): - self.table.selection = [row] - - pn.state.execute(_do_update, schedule=True) - self._panel_update_visible_pair(row) + def _panel_on_selection_changed(self): + # called by SelectableTabulator whenever the selection changes (click or keyboard) + selected = self.table.selection + if len(selected) == 0: + return + self._panel_update_visible_pair(selected[0]) def _panel_include_deleted_change(self, event): self.include_deleted = event.new self.refresh() + def _panel_exclude_noise_change(self, event): + self.exclude_noise = event.new + self.refresh() + def _panel_update_visible_pair(self, row): table_row = self.table.value.iloc[row] visible_unit_ids = [] @@ -465,41 +503,23 @@ def _panel_update_visible_pair(self, row): self.notify_unit_visibility_changed() def _panel_handle_shortcut(self, event): - import panel as pn - - if event.data == "accept": - selected = self.table.selection - if len(selected) == 0: - return - # selected is always 1 - row = selected[0] - group_ids = self.table.value.iloc[row].group_ids - self.accept_group_merge(group_ids) - self.notify_manual_curation_updated() - - next_row = min(row + 1, len(self.table.value) - 1) - - def _select_next(): - self.table.selection = [next_row] - - pn.state.execute(_select_next, schedule=True) - self._panel_update_visible_pair(next_row) - elif event.data == "next": - next_row = min(self.table.selection[0] + 1, len(self.table.value) - 1) - - def _do_next(): - self.table.selection = [next_row] - - pn.state.execute(_do_next, schedule=True) - self._panel_update_visible_pair(next_row) - elif event.data == "previous": - previous_row = max(self.table.selection[0] - 1, 0) - - def _do_prev(): - self.table.selection = [previous_row] + if event.data != "accept": + return + if not self.is_view_active(): + return + if self.table is None: + return + selected = self.table.selection + if len(selected) == 0: + return + # selected is always 1 + row = selected[0] + group_ids = self.table.value.iloc[row].group_ids + self.accept_group_merge(group_ids) - pn.state.execute(_do_prev, schedule=True) - self._panel_update_visible_pair(previous_row) + # advance to the next row; the selection setter triggers _panel_on_selection_changed + next_row = min(row + 1, len(self.table.value) - 1) + self.table.selection = [next_row] def _panel_on_spike_selection_changed(self): pass diff --git a/spikeinterface_gui/ndscatterview.py b/spikeinterface_gui/ndscatterview.py index e5a3ad4..77864c3 100644 --- a/spikeinterface_gui/ndscatterview.py +++ b/spikeinterface_gui/ndscatterview.py @@ -3,6 +3,7 @@ """ import itertools +import warnings import numpy as np from matplotlib.path import Path as mpl_path @@ -57,6 +58,7 @@ def __init__(self, controller=None, parent=None, backend="qt"): self.auto_update_limit = True self._lasso_vertices = [] self._current_selected = 0 + self._best_mode = False ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) @@ -85,6 +87,7 @@ def new_tour_step(self): self._refresh(update_colors=False, update_components=False) def next_face(self): + self._best_mode = False self.n_face += 1 self.n_face = self.n_face%len(self.hyper_faces) ndim = self.data.shape[1] @@ -105,17 +108,100 @@ def get_one_random_projection(self): return projection def random_projection(self): + self._best_mode = False self.update_selected_components() self.projection = self.get_one_random_projection() self.tour_step = 0 # here we don't want to update the components because it's been done already! self.refresh(update_components=False) + def best_projection(self): + # Note: for best projection we don't restrict to the visible channels, because we want to find the + # best projection over all channels. We therefore mark all components as active so that the + # full projection is rendered (see apply_dot, which slices by selected_comp). + self.selected_comp[:] = True + + X_list, y_list = [], [] + for unit_ind, unit_id in self.controller.iter_visible_units(): + mask = np.flatnonzero(self.pc_unit_index == unit_ind) + if len(mask) > 0: + X_list.append(self.data[mask, :]) + y_list.extend([unit_ind] * len(mask)) + + if len(X_list) == 0: + return + + X = np.vstack(X_list) + y = np.array(y_list) + n_classes = len(np.unique(y)) + + ndim = self.data.shape[1] + projection = None + + # For LDA, we need at least 2 samples per class and at least 2 classes. + # For PCA, we need at least 2 samples and at least 1 feature. + if X.shape[1] == 0 or X.shape[0] < 2: + pass + else: + try: + projection = np.zeros((ndim, 2)) + if n_classes >= 2: + # multiple units: maximize class separation with LDA + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis + + n_proj = min(2, n_classes - 1) + lda = LinearDiscriminantAnalysis(n_components=n_proj) + lda.fit(X, y) + directions = lda.scalings_[:, :n_proj] + else: + # single unit: pick the directions with highest variance explained (PCA) + from sklearn.decomposition import PCA + + n_proj = min(2, X.shape[1]) + pca = PCA(n_components=n_proj) + pca.fit(X) + directions = pca.components_[:n_proj].T + + projection[:, :n_proj] = directions + + for j in range(n_proj): + norm = np.linalg.norm(projection[:, j]) + if norm > 0: + projection[:, j] /= norm + + # only 1 projection direction: fill second axis with random orthogonal direction + if n_proj < 2: + main_dir = projection[:, 0] + rand_dir = np.random.randn(ndim) + rand_dir -= np.dot(rand_dir, main_dir) * main_dir + norm = np.linalg.norm(rand_dir) + if norm > 0: + rand_dir /= norm + projection[:, 1] = rand_dir + except (ValueError, IndexError) as e: + # degenerate data (e.g. zero-rank / constant features): keep current projection + warnings.warn(f"NDScatter best_projection failed: {e}. Keeping current projection.") + projection = None + + self._best_mode = True + # if computing the best projection failed, keep the current one but still + # refresh so the newly selected units are shown + if projection is not None: + self.projection = projection + self.tour_step = 0 + self.refresh(update_components=False) + def on_unit_visibility_changed(self): - self.random_projection() - + if self._best_mode: + self.best_projection() + else: + self.random_projection() + def on_channel_visibility_changed(self): - self.random_projection() + if self._best_mode: + self.best_projection() + else: + self.random_projection() def apply_dot(self, data): projected = np.dot(data[:, self.selected_comp], self.projection[self.selected_comp, :]) @@ -172,6 +258,10 @@ def get_plotting_data(self, return_spike_indices=False): def update_selected_components(self): + # in best projection mode all components are used (see best_projection) + if self._best_mode: + self.selected_comp[:] = True + return n_pc_per_chan = self.pc_data.shape[1] n = min(self.settings['num_pc_per_channel'], n_pc_per_chan) self.selected_comp[:] = False @@ -182,24 +272,34 @@ def update_selected_components(self): def _qt_make_layout(self): from .myqt import QT import pyqtgraph as pg - from .utils_qt import ViewBoxHandlingLassoAndGain, add_stretch_to_qtoolbar + from .utils_qt import ViewBoxHandlingLassoAndGain, add_stretch_to_qtoolbar, qt_style + self.layout = QT.QVBoxLayout() - self.layout = QT.QHBoxLayout() + # Row 2 toolbar (Row 1 is the ViewWidget toolbar with settings/refresh/?) + tb = QT.QToolBar() + tb.setStyleSheet(qt_style) + self.layout.addWidget(tb) - # toolbar - tb = self.qt_widget.view_toolbar - but = QT.QPushButton('Random') - tb.addWidget(but) - but.clicked.connect(self.random_projection) - but = QT.QPushButton('Random tour', checkable = True) + but = QT.QPushButton('Next Face') tb.addWidget(but) - but.clicked.connect(self._qt_start_stop_tour) + but.clicked.connect(self._qt_on_next_face) + + self._qt_best_but = QT.QPushButton('Best') + self._qt_best_but.setCheckable(True) + tb.addWidget(self._qt_best_but) + self._qt_best_but.clicked.connect(self._qt_on_best_projection) - but = QT.QPushButton('next face') + but = QT.QPushButton('Random') tb.addWidget(but) - but.clicked.connect(self.next_face) + but.clicked.connect(self._qt_on_random_projection) + + self._qt_random_tour_but = QT.QPushButton('Random Tour') + self._qt_random_tour_but.setCheckable(True) + tb.addWidget(self._qt_random_tour_but) + self._qt_random_tour_but.clicked.connect(self._qt_start_stop_tour) + add_stretch_to_qtoolbar(tb) self.graphicsview = pg.GraphicsView() self.layout.addWidget(self.graphicsview) @@ -341,9 +441,26 @@ def _qt_refresh(self, update_components=True, update_colors=True): def _qt_on_spike_selection_changed(self): self.refresh() - + + def _qt_on_next_face(self): + self._qt_best_but.setChecked(False) + self.next_face() + + def _qt_on_random_projection(self): + self._qt_best_but.setChecked(False) + self.random_projection() + + def _qt_on_best_projection(self, checked): + if checked: + if self._qt_random_tour_but.isChecked(): + self._qt_random_tour_but.setChecked(False) + self._qt_start_stop_tour(False) + if not self._best_mode: + self.best_projection() + def _qt_start_stop_tour(self, checked): if checked: + self._qt_best_but.setChecked(False) self.tour_step = 0 self.timer_tour.setInterval(int(self.settings['refresh_interval'])) self.timer_tour.start() @@ -435,7 +552,10 @@ def _panel_make_layout(self): self.random_button = pn.widgets.Button(name="Random", button_type="default", width=100) self.random_button.on_click(self._panel_random_projection) - self.random_tour_button = pn.widgets.Toggle(name="Random Tour", button_type="default", width=100) + self.best_button = pn.widgets.Toggle(name="Best", button_type="default", width=100) + self.best_button.param.watch(self._panel_best_projection, "value") + + self.random_tour_button = pn.widgets.Toggle(name="Tour", button_type="default", width=100) self.random_tour_button.param.watch(self._panel_start_stop_tour, "value") self.select_toggle_button = pn.widgets.Toggle(name="Select") @@ -443,10 +563,23 @@ def _panel_make_layout(self): self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry) - self.toolbar = pn.Row( - self.next_face_button, self.random_button, self.random_tour_button, self.select_toggle_button, + first_row = pn.Row( + self.next_face_button, + self.select_toggle_button, + sizing_mode="stretch_both", + ) + second_row = pn.Row( + self.best_button, + self.random_button, + self.random_tour_button, + sizing_mode="stretch_both", + ) + + self.toolbar = pn.Column( + first_row, + second_row, sizing_mode="stretch_both", - styles={"flex": "0.15"} + styles={"flex": "0.25"} ) self.layout = pn.Column( @@ -459,6 +592,8 @@ def _panel_make_layout(self): self.tour_timer = None def _panel_refresh(self, update_components=True, update_colors=True): + import panel as pn + if update_components: self.update_selected_components() scatter_x, scatter_y, _, _, spike_indices = self.get_plotting_data(return_spike_indices=True) @@ -475,17 +610,22 @@ def _panel_refresh(self, update_components=True, update_colors=True): if not update_colors: colors = self.scatter_source.data.get("color") - self.scatter_source.data = { + data = { "x": xs, "y": ys, "color": colors, "spike_indices": plotted_spike_indices } + limit = self.limit + + def _do_update(): + self.scatter_source.data = data + self.scatter_fig.x_range.start = -limit + self.scatter_fig.x_range.end = limit + self.scatter_fig.y_range.start = -limit + self.scatter_fig.y_range.end = limit - self.scatter_fig.x_range.start = -self.limit - self.scatter_fig.x_range.end = self.limit - self.scatter_fig.y_range.start = -self.limit - self.scatter_fig.y_range.end = self.limit + pn.state.execute(_do_update, schedule=True) def _panel_on_spike_selection_changed(self): import panel as pn @@ -515,14 +655,24 @@ def _do_update(): pn.state.execute(_do_update, schedule=True) def _panel_next_face(self, event): + self.best_button.value = False self.next_face() def _panel_random_projection(self, event): + self.best_button.value = False self.random_projection() + def _panel_best_projection(self, event): + if event.new: + if self.random_tour_button.value: + self.random_tour_button.value = False + if not self._best_mode: + self.best_projection() + def _panel_start_stop_tour(self, event): import panel as pn if event.new: + self.best_button.value = False self.tour_step = 0 self.tour_timer = pn.state.add_periodic_callback(self.new_tour_step, period=self.settings['refresh_interval']) self.auto_update_limit = False diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index d1eb4dd..82b3459 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -520,10 +520,20 @@ def _panel_refresh(self): # Pre-compute circle updates circle_update = None - if len(selected_unit_indices) == 1: + cx, cy = None, None + n = len(selected_unit_indices) + if n == 1: + # always refresh the channel ROI unit_index = selected_unit_indices[0] - unit_positions = self.controller.unit_positions - cx, cy = unit_positions[unit_index, 0], unit_positions[unit_index, 1] + cx, cy = self.controller.unit_positions[unit_index, :] + elif n > 1: + # change ROI only if all units are inside the radius + positions = self.controller.unit_positions[selected_unit_indices, :] + distances = np.linalg.norm(positions[:, np.newaxis] - positions[np.newaxis, :], axis=2) + if np.max(distances) < (self.settings['radius_units'] * 2): + cx, cy = np.mean(positions, axis=0) + + if cx is not None: visible_channel_inds = self.update_channel_visibility(cx, cy, radius_channel) circle_update = (cx, cy, visible_channel_inds) diff --git a/spikeinterface_gui/tests/test_mainwindow_panel.py b/spikeinterface_gui/tests/test_mainwindow_panel.py index 3971078..05d5e19 100644 --- a/spikeinterface_gui/tests/test_mainwindow_panel.py +++ b/spikeinterface_gui/tests/test_mainwindow_panel.py @@ -10,22 +10,11 @@ import numpy as np -# import logging - - -# logger = logging.getLogger('bokeh') -# logger.setLevel(logging.DEBUG) -# logging.basicConfig(level=logging.DEBUG) - - -# test_folder = Path(__file__).parent / 'my_dataset_small' -test_folder = Path(__file__).parent / 'my_dataset_big' -# test_folder = Path(__file__).parent / 'my_dataset_multiprobe' - def setup_module(): + global test_folder case = test_folder.stem.split('_')[-1] - make_analyzer_folder(test_folder, case=case) + make_analyzer_folder(test_folder, case=case, unit_dtype="int") def teardown_module(): clean_all(test_folder) @@ -127,14 +116,10 @@ def test_launcher(verbose=True): if __name__ == '__main__': args = parser.parse_args() dataset = args.dataset - if dataset == "small": - test_folder = Path(__file__).parent / 'my_dataset_small' - elif dataset == "big": - test_folder = Path(__file__).parent / 'my_dataset_big' - elif dataset == "multiprobe": - test_folder = Path(__file__).parent / 'my_dataset_multiprobe' - else: - test_folder = Path(dataset) + global test_folder + if dataset is not None: + test_folder = Path(__file__).parents[2] / f"my_dataset_{dataset}" + if not test_folder.is_dir(): setup_module() diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index 3349eba..0fbc3ef 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -29,7 +29,7 @@ def setup_module(): global test_folder case = test_folder.stem.split('_')[-1] - make_analyzer_folder(test_folder, case=case) + make_analyzer_folder(test_folder, case=case, unit_dtype="int") def teardown_module(): clean_all(test_folder) diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py index dcd87af..113451a 100644 --- a/spikeinterface_gui/tests/testingtools.py +++ b/spikeinterface_gui/tests/testingtools.py @@ -37,7 +37,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): num_channels = 32 num_units = 16 else: - raise ValueError() + raise ValueError(f"Wrong dataset type {case}") diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index b34f354..995c529 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -273,7 +273,6 @@ def _panel_make_layout(self): ) def _panel_refresh(self): - self._panel_remove_event_line() t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index c501dc8..120410c 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -750,10 +750,13 @@ def _panel_on_only_selection(self): updated_visibile_units = self.controller.get_visible_unit_ids() if set(current_visible_units) != set(updated_visibile_units): self._panel_refresh_colors() - # update the visible column - df = self.table.value - df.loc[self.controller.unit_ids, "visible"] = self.controller.get_units_visibility_mask() - self.table.value = df + # update the visible column in place (patch_column avoids resetting the + # table scroll position, which a full `self.table.value = df` would do) + self.table.patch_column( + "visible", + list(self.controller.get_units_visibility_mask()), + list(self.controller.unit_ids), + ) self.notify_unit_and_channel_visibility_changed() def _panel_get_selected_unit_ids(self): diff --git a/spikeinterface_gui/utils_panel.py b/spikeinterface_gui/utils_panel.py index 2cad88f..2f79476 100644 --- a/spikeinterface_gui/utils_panel.py +++ b/spikeinterface_gui/utils_panel.py @@ -5,6 +5,7 @@ except ImportError: from typing_extensions import NotRequired +import re import numpy as np import time import panel as pn @@ -63,7 +64,7 @@ unit_formatter = HTMLTemplateFormatter( template="""
- ● <%= value ? value.id : '' %> + ● <%= value ? value.id : '' %><%= value && value.n !== undefined ? ' n=' + value.n : '' %>
""" ) @@ -288,6 +289,8 @@ class SelectableTabulator(pn.viewable.Viewer): ---------- *args, **kwargs Arguments passed to the Tabulator constructor. + skip_sort_columns: list[str] + Columns to exclude from the "Sort by" dropdown options. parent_view: ViewBase | None The parent view that will be notified of selection changes. on_selection_changed: Callable | None @@ -314,6 +317,8 @@ def __init__( self._formatters = kwargs.get("formatters", {}) self._editors = kwargs.get("editors", {}) self._frozen_columns = kwargs.get("frozen_columns", []) + # columns to hide from the view but keep in the underlying dataframe + self._hidden_columns = list(kwargs.pop("hidden_columns", [])) self._selectable = kwargs.get("selectable", True) if "sortable" in kwargs: self._sortable = kwargs.pop("sortable") @@ -440,6 +445,7 @@ def refresh_tabulator_settings(self): self.tabulator.formatters = self._formatters self.tabulator.editors = self._editors self.tabulator.frozen_columns = self._frozen_columns + self.tabulator.hidden_columns = self._hidden_columns self.tabulator.selectable = self._selectable self.tabulator.sorters = [] @@ -478,10 +484,18 @@ def _on_sort_change(self, event): ascending=(self.direction_dropdown.value == "↑") ) else: - df = self.tabulator.value.sort_values( - by=self.sort_dropdown.value, - ascending=(self.direction_dropdown.value == "↑") + import pandas.api.types as ptypes + + col = self.sort_dropdown.value + sort_kwargs = dict( + by=col, + ascending=(self.direction_dropdown.value == "↑"), ) + if ptypes.is_string_dtype(self.tabulator.value[col]): + sort_kwargs["key"] = lambda x: x.map( + lambda v: [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', str(v))] + ) + df = self.tabulator.value.sort_values(**sort_kwargs) self.tabulator.value = df def _on_selection_change(self, event): diff --git a/spikeinterface_gui/utils_qt.py b/spikeinterface_gui/utils_qt.py index b4d35c1..2c14a4f 100644 --- a/spikeinterface_gui/utils_qt.py +++ b/spikeinterface_gui/utils_qt.py @@ -16,6 +16,11 @@ font-size: 10px; } +QPushButton:checked { + background-color: #3a6ea5; + color: white; +} + QComboBox{ min-width: 100px; max-width: 120px;