From d09c07281ddf6b003005865429dcc9a681d3fcb3 Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Wed, 8 Apr 2026 23:20:21 +0200 Subject: [PATCH 01/10] Reorganize tmp/ directory --- tmp/correlations/ed-1.py | 89 +++++++++++++++++++ tmp/{ => unsorted}/Untitled.ipynb | 0 tmp/{ => unsorted}/Untitled0.ipynb | 0 tmp/{ => unsorted}/Untitled1.ipynb | 0 tmp/{ => unsorted}/Untitled2.ipynb | 0 tmp/{ => unsorted}/__validator.py | 0 tmp/{ => unsorted}/_gemmi.py | 0 tmp/{ => unsorted}/_read_cif.py | 0 tmp/{ => unsorted}/_smart.py | 0 .../basic_single-fit_pd-neut-cwl_LBCO-HRPT.py | 0 ...truct_pd-neut-tof_multiphase-BSFTO-HRPT.py | 0 tmp/{ => unsorted}/display.py | 0 tmp/{ => unsorted}/display2.py | 0 tmp/{ => unsorted}/display3-Copy1.py | 0 tmp/{ => unsorted}/display3.py | 0 .../generate_overview_mermaid.py | 0 ...0p12Fe0p94Ti0p06O3_DW_V_9x8x52_1p49_HI.xye | 0 tmp/{ => unsorted}/short.py | 0 tmp/{ => unsorted}/short2.py | 0 tmp/{ => unsorted}/short3.py | 0 tmp/{ => unsorted}/short5.py | 0 tmp/{ => unsorted}/short6.py | 0 tmp/{ => unsorted}/short7.py | 0 tmp/{ => unsorted}/show_d401.py | 0 tmp/{ => unsorted}/show_w505.py | 0 ...test_single-fit_pd-neut-tof_Si-DREAM_nc.py | 0 26 files changed, 89 insertions(+) create mode 100644 tmp/correlations/ed-1.py rename tmp/{ => unsorted}/Untitled.ipynb (100%) rename tmp/{ => unsorted}/Untitled0.ipynb (100%) rename tmp/{ => unsorted}/Untitled1.ipynb (100%) rename tmp/{ => unsorted}/Untitled2.ipynb (100%) rename tmp/{ => unsorted}/__validator.py (100%) rename tmp/{ => unsorted}/_gemmi.py (100%) rename tmp/{ => unsorted}/_read_cif.py (100%) rename tmp/{ => unsorted}/_smart.py (100%) rename tmp/{ => unsorted}/basic_single-fit_pd-neut-cwl_LBCO-HRPT.py (100%) rename tmp/{ => unsorted}/cryst-struct_pd-neut-tof_multiphase-BSFTO-HRPT.py (100%) rename tmp/{ => unsorted}/display.py (100%) rename tmp/{ => unsorted}/display2.py (100%) rename tmp/{ => unsorted}/display3-Copy1.py (100%) rename tmp/{ => unsorted}/display3.py (100%) rename tmp/{ => unsorted}/generate_overview_mermaid.py (100%) rename tmp/{ => unsorted}/hrpt_n_Bi0p88Sm0p12Fe0p94Ti0p06O3_DW_V_9x8x52_1p49_HI.xye (100%) rename tmp/{ => unsorted}/short.py (100%) rename tmp/{ => unsorted}/short2.py (100%) rename tmp/{ => unsorted}/short3.py (100%) rename tmp/{ => unsorted}/short5.py (100%) rename tmp/{ => unsorted}/short6.py (100%) rename tmp/{ => unsorted}/short7.py (100%) rename tmp/{ => unsorted}/show_d401.py (100%) rename tmp/{ => unsorted}/show_w505.py (100%) rename tmp/{ => unsorted}/test_single-fit_pd-neut-tof_Si-DREAM_nc.py (100%) diff --git a/tmp/correlations/ed-1.py b/tmp/correlations/ed-1.py new file mode 100644 index 000000000..e654343d5 --- /dev/null +++ b/tmp/correlations/ed-1.py @@ -0,0 +1,89 @@ +# %% [markdown] +# # Structure Refinement: LBCO, HRPT +# +# This minimalistic example is designed to show how Rietveld refinement +# can be performed when both the crystal structure and experiment +# parameters are defined using CIF files. +# +# For this example, constant-wavelength neutron powder diffraction data +# for La0.5Ba0.5CoO3 from HRPT at PSI is used. +# +# It does not contain any advanced features or options, and includes no +# comments or explanations—these can be found in the other tutorials. +# Default values are used for all parameters if not specified. Only +# essential and self-explanatory code is provided. +# +# The example is intended for users who are already familiar with the +# EasyDiffraction library and want to quickly get started with a simple +# refinement. It is also useful for those who want to see what a +# refinement might look like in code. For a more detailed explanation of +# the code, please refer to the other tutorials. + +# %% [markdown] +# ## Import Library + +# %% +import easydiffraction as ed + +# %% [markdown] +# ## Step 1: Define Project + +# %% +# Create minimal project without name and description +project = ed.Project() + +# %% [markdown] +# ## Step 2: Define Crystal Structure + +# %% +# Download CIF file from repository +structure_path = ed.download_data(id=1, destination='data') + +# %% +# Add structure from downloaded CIF +project.structures.add_from_cif_path(structure_path) + +# %% [markdown] +# ## Step 3: Define Experiment + +# %% +# Download CIF file from repository +expt_path = ed.download_data(id=2, destination='data') + +# %% +# Add experiment from downloaded CIF +project.experiments.add_from_cif_path(expt_path) + +# %% [markdown] +# ## Step 4: Perform Analysis (cryspy) + +# %% +# Define aliases and constraints for refinement. This is necessary to +# properly refine the isotropic displacement parameters of La and Ba, +# which are correlated due to their shared Wyckoff position. +project.analysis.aliases.create( + label='biso_La', + param=project.structures['lbco'].atom_sites['La'].b_iso, +) +project.analysis.aliases.create( + label='biso_Ba', + param=project.structures['lbco'].atom_sites['Ba'].b_iso, +) +project.analysis.constraints.create(expression='biso_Ba = biso_La') + +# %% +# Start refinement. All parameters, which have standard uncertainties +# in the input CIF files, are refined by default. +project.analysis.fit() + +# %% +# Show fit results summary +project.analysis.display.fit_results() + +# %% +# Show defined experiment names +project.experiments.show_names() + +# %% +# Plot measured vs. calculated diffraction patterns +project.plotter.plot_meas_vs_calc(expt_name='hrpt', show_residual=True) diff --git a/tmp/Untitled.ipynb b/tmp/unsorted/Untitled.ipynb similarity index 100% rename from tmp/Untitled.ipynb rename to tmp/unsorted/Untitled.ipynb diff --git a/tmp/Untitled0.ipynb b/tmp/unsorted/Untitled0.ipynb similarity index 100% rename from tmp/Untitled0.ipynb rename to tmp/unsorted/Untitled0.ipynb diff --git a/tmp/Untitled1.ipynb b/tmp/unsorted/Untitled1.ipynb similarity index 100% rename from tmp/Untitled1.ipynb rename to tmp/unsorted/Untitled1.ipynb diff --git a/tmp/Untitled2.ipynb b/tmp/unsorted/Untitled2.ipynb similarity index 100% rename from tmp/Untitled2.ipynb rename to tmp/unsorted/Untitled2.ipynb diff --git a/tmp/__validator.py b/tmp/unsorted/__validator.py similarity index 100% rename from tmp/__validator.py rename to tmp/unsorted/__validator.py diff --git a/tmp/_gemmi.py b/tmp/unsorted/_gemmi.py similarity index 100% rename from tmp/_gemmi.py rename to tmp/unsorted/_gemmi.py diff --git a/tmp/_read_cif.py b/tmp/unsorted/_read_cif.py similarity index 100% rename from tmp/_read_cif.py rename to tmp/unsorted/_read_cif.py diff --git a/tmp/_smart.py b/tmp/unsorted/_smart.py similarity index 100% rename from tmp/_smart.py rename to tmp/unsorted/_smart.py diff --git a/tmp/basic_single-fit_pd-neut-cwl_LBCO-HRPT.py b/tmp/unsorted/basic_single-fit_pd-neut-cwl_LBCO-HRPT.py similarity index 100% rename from tmp/basic_single-fit_pd-neut-cwl_LBCO-HRPT.py rename to tmp/unsorted/basic_single-fit_pd-neut-cwl_LBCO-HRPT.py diff --git a/tmp/cryst-struct_pd-neut-tof_multiphase-BSFTO-HRPT.py b/tmp/unsorted/cryst-struct_pd-neut-tof_multiphase-BSFTO-HRPT.py similarity index 100% rename from tmp/cryst-struct_pd-neut-tof_multiphase-BSFTO-HRPT.py rename to tmp/unsorted/cryst-struct_pd-neut-tof_multiphase-BSFTO-HRPT.py diff --git a/tmp/display.py b/tmp/unsorted/display.py similarity index 100% rename from tmp/display.py rename to tmp/unsorted/display.py diff --git a/tmp/display2.py b/tmp/unsorted/display2.py similarity index 100% rename from tmp/display2.py rename to tmp/unsorted/display2.py diff --git a/tmp/display3-Copy1.py b/tmp/unsorted/display3-Copy1.py similarity index 100% rename from tmp/display3-Copy1.py rename to tmp/unsorted/display3-Copy1.py diff --git a/tmp/display3.py b/tmp/unsorted/display3.py similarity index 100% rename from tmp/display3.py rename to tmp/unsorted/display3.py diff --git a/tmp/generate_overview_mermaid.py b/tmp/unsorted/generate_overview_mermaid.py similarity index 100% rename from tmp/generate_overview_mermaid.py rename to tmp/unsorted/generate_overview_mermaid.py diff --git a/tmp/hrpt_n_Bi0p88Sm0p12Fe0p94Ti0p06O3_DW_V_9x8x52_1p49_HI.xye b/tmp/unsorted/hrpt_n_Bi0p88Sm0p12Fe0p94Ti0p06O3_DW_V_9x8x52_1p49_HI.xye similarity index 100% rename from tmp/hrpt_n_Bi0p88Sm0p12Fe0p94Ti0p06O3_DW_V_9x8x52_1p49_HI.xye rename to tmp/unsorted/hrpt_n_Bi0p88Sm0p12Fe0p94Ti0p06O3_DW_V_9x8x52_1p49_HI.xye diff --git a/tmp/short.py b/tmp/unsorted/short.py similarity index 100% rename from tmp/short.py rename to tmp/unsorted/short.py diff --git a/tmp/short2.py b/tmp/unsorted/short2.py similarity index 100% rename from tmp/short2.py rename to tmp/unsorted/short2.py diff --git a/tmp/short3.py b/tmp/unsorted/short3.py similarity index 100% rename from tmp/short3.py rename to tmp/unsorted/short3.py diff --git a/tmp/short5.py b/tmp/unsorted/short5.py similarity index 100% rename from tmp/short5.py rename to tmp/unsorted/short5.py diff --git a/tmp/short6.py b/tmp/unsorted/short6.py similarity index 100% rename from tmp/short6.py rename to tmp/unsorted/short6.py diff --git a/tmp/short7.py b/tmp/unsorted/short7.py similarity index 100% rename from tmp/short7.py rename to tmp/unsorted/short7.py diff --git a/tmp/show_d401.py b/tmp/unsorted/show_d401.py similarity index 100% rename from tmp/show_d401.py rename to tmp/unsorted/show_d401.py diff --git a/tmp/show_w505.py b/tmp/unsorted/show_w505.py similarity index 100% rename from tmp/show_w505.py rename to tmp/unsorted/show_w505.py diff --git a/tmp/test_single-fit_pd-neut-tof_Si-DREAM_nc.py b/tmp/unsorted/test_single-fit_pd-neut-tof_Si-DREAM_nc.py similarity index 100% rename from tmp/test_single-fit_pd-neut-tof_Si-DREAM_nc.py rename to tmp/unsorted/test_single-fit_pd-neut-tof_Si-DREAM_nc.py From 3532ab5644dbeb7e210aad6659070f29d5f35ebf Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 01:32:09 +0200 Subject: [PATCH 02/10] Implement parameter correlations table/plot --- .../display/plotters/plotly.py | 198 +++++++- src/easydiffraction/display/plotting.py | 480 ++++++++++++++++++ .../display/plotters/test_plotly.py | 36 ++ .../easydiffraction/display/test_plotting.py | 373 ++++++++++++++ tmp/correlations/ed-1.py | 32 +- 5 files changed, 1112 insertions(+), 7 deletions(-) diff --git a/src/easydiffraction/display/plotters/plotly.py b/src/easydiffraction/display/plotters/plotly.py index 77d351e1b..c772754a9 100644 --- a/src/easydiffraction/display/plotters/plotly.py +++ b/src/easydiffraction/display/plotters/plotly.py @@ -9,6 +9,7 @@ """ import darkdetect +import numpy as np import plotly.graph_objects as go import plotly.io as pio @@ -21,6 +22,8 @@ from easydiffraction.display.plotters.base import SERIES_CONFIG from easydiffraction.display.plotters.base import PlotterBase +from easydiffraction.utils._vendored.theme_detect import is_dark +from easydiffraction.utils.environment import in_jupyter from easydiffraction.utils.environment import in_pycharm DEFAULT_COLORS = { @@ -33,9 +36,198 @@ class PlotlyPlotter(PlotterBase): """Interactive plotter using Plotly for notebooks and browsers.""" - pio.templates.default = 'plotly_dark' if darkdetect.isDark() else 'plotly_white' - if in_pycharm(): - pio.renderers.default = 'browser' + def __init__(self) -> None: + if hasattr(pio, 'templates'): + pio.templates.default = self._default_template_name() + if in_pycharm(): + pio.renderers.default = 'browser' + + @staticmethod + def _is_dark_mode() -> bool: + """ + Return whether the active plotting context should use dark mode. + + In Jupyter, prefer notebook dark-mode detection. Outside + Jupyter, fall back to the system theme via ``darkdetect``. + + Returns + ------- + bool + ``True`` for dark mode, otherwise ``False``. + """ + return is_dark() if in_jupyter() else darkdetect.isDark() + + @classmethod + def _default_template_name(cls) -> str: + """ + Return the Plotly template matching the active theme. + + In Jupyter, prefer notebook dark-mode detection. Outside + Jupyter, fall back to the system theme via ``darkdetect``. + + Returns + ------- + str + Either ``'plotly_dark'`` or ``'plotly_white'``. + """ + return 'plotly_dark' if cls._is_dark_mode() else 'plotly_white' + + @classmethod + def _correlation_colorscale(cls) -> list[tuple[float, str]]: + """ + Return a diverging colorscale for correlation heatmaps. + + Dark mode uses black at zero correlation for lower visual + prominence. Light mode uses white at zero correlation. + + Returns + ------- + list[tuple[float, str]] + Plotly-compatible colorscale definition. + """ + if cls._is_dark_mode(): + return [ + (0.0, '#d73027'), + (0.5, '#000000'), + (1.0, '#4575b4'), + ] + return [ + (0.0, '#d73027'), + (0.5, '#f7f7f7'), + (1.0, '#4575b4'), + ] + + @classmethod + def _correlation_grid_color(cls) -> str: + """ + Return the boundary-line color for correlation heatmaps. + + Returns + ------- + str + RGBA color string tuned for the active theme. + """ + if cls._is_dark_mode(): + return 'rgba(110, 145, 190, 0.35)' + return 'rgba(120, 140, 160, 0.28)' + + def plot_correlation_heatmap( + self, + corr_df: object, + title: str, + ) -> None: + """ + Render a Plotly heatmap for a correlation matrix. + + Parameters + ---------- + corr_df : object + Square correlation DataFrame. + title : str + Figure title. + """ + num_rows, num_cols = corr_df.shape + x_edges = np.arange(num_cols + 1, dtype=float) + y_edges = np.arange(num_rows + 1, dtype=float) + x_centers = np.arange(num_cols, dtype=float) + 0.5 + y_centers = np.arange(num_rows, dtype=float) + 0.5 + grid_color = self._correlation_grid_color() + + heatmap = go.Heatmap( + z=corr_df.to_numpy(), + x=x_edges, + y=y_edges, + zmin=-1.0, + zmax=1.0, + zmid=0.0, + colorscale=self._correlation_colorscale(), + colorbar={ + 'title': {'text': ''}, + 'lenmode': 'fraction', + 'len': 1.0, + 'y': 0.5, + 'yanchor': 'middle', + }, + hoverongaps=False, + hovertemplate='x: %{x}
y: %{y}
corr: %{z:.3f}', + ) + + shapes = [ + { + 'type': 'line', + 'x0': float(x_pos), + 'x1': float(x_pos), + 'y0': 0.0, + 'y1': float(num_rows), + 'xref': 'x', + 'yref': 'y', + 'layer': 'above', + 'line': {'color': grid_color, 'width': 1}, + } + for x_pos in x_edges[1:-1] + ] + shapes.extend( + { + 'type': 'line', + 'x0': 0.0, + 'x1': float(num_cols), + 'y0': float(y_pos), + 'y1': float(y_pos), + 'xref': 'x', + 'yref': 'y', + 'layer': 'above', + 'line': {'color': grid_color, 'width': 1}, + } + for y_pos in y_edges[1:-1] + ) + shapes.append({ + 'type': 'rect', + 'x0': 0.0, + 'x1': 1.0, + 'y0': 0.0, + 'y1': 1.0, + 'xref': 'paper', + 'yref': 'paper', + 'layer': 'above', + 'line': {'color': grid_color, 'width': 1}, + 'fillcolor': 'rgba(0, 0, 0, 0)', + }) + + layout = self._get_layout( + title, + ['Parameter', 'Parameter'], + shapes=shapes, + ) + fig = self._get_figure([heatmap], layout) + fig.update_xaxes( + side='bottom', + tickangle=-45, + automargin=True, + tickmode='array', + tickvals=x_centers.tolist(), + ticktext=corr_df.columns.tolist(), + range=[0.0, float(num_cols)], + showgrid=False, + showline=False, + mirror=False, + ticks='', + layer='above traces', + ) + fig.update_yaxes( + autorange='reversed', + automargin=True, + tickmode='array', + tickvals=y_centers.tolist(), + ticktext=corr_df.index.tolist(), + ticklabelstandoff=8, + range=[float(num_rows), 0.0], + showgrid=False, + showline=False, + mirror=False, + ticks='', + layer='above traces', + ) + self._show_figure(fig) @staticmethod def _get_powder_trace( diff --git a/src/easydiffraction/display/plotting.py b/src/easydiffraction/display/plotting.py index 98c546d74..40add3da3 100644 --- a/src/easydiffraction/display/plotting.py +++ b/src/easydiffraction/display/plotting.py @@ -53,6 +53,10 @@ def description(self) -> str: return '' +DEFAULT_CORRELATION_THRESHOLD = 0.7 +EXPECTED_COVAR_NDIM = 2 + + class Plotter(RendererBase): """User-facing plotting facade backed by concrete plotters.""" @@ -498,6 +502,482 @@ def plot_param_series( self._project.analysis._parameter_snapshots, ) + def plot_param_correlations( + self, + threshold: float | None = DEFAULT_CORRELATION_THRESHOLD, + show_diagonal: bool = False, + triangle: str = 'lower', + precision: int = 2, + ) -> None: + """ + Plot the parameter correlation matrix from the latest fit. + + The matrix is taken from ``project.analysis.fit_results``. When + the active engine is Plotly, an interactive heatmap is shown. + Otherwise, a rounded correlation table is rendered. + + Parameters + ---------- + threshold : float | None, default=DEFAULT_CORRELATION_THRESHOLD + Minimum absolute off-diagonal correlation required for a + parameter to be shown. Parameters are kept only if they + participate in at least one pair with ``abs(correlation) >= + threshold``. Set to ``None`` or ``0`` to show the full + matrix. + show_diagonal : bool, default=False + Whether to show self-correlations on the diagonal. The + default hides them because they are always ``1`` and do not + add information. + triangle : str, default='lower' + Which half of the symmetric matrix to show. Supported values + are ``'lower'``, ``'upper'``, and ``'full'``. + precision : int, default=2 + Number of decimal places to show in the table fallback. + """ + corr_df = self._get_param_correlation_dataframe() + if corr_df is None: + return + + corr_df = self._filter_correlation_dataframe(corr_df, threshold=threshold) + if corr_df is None: + return + + corr_df = self._mask_correlation_triangle( + corr_df, + triangle=triangle, + show_diagonal=show_diagonal, + ) + title = 'Refined parameter correlation matrix' + if threshold is not None and threshold > 0: + title += f' with |correlation| >= {threshold:.2f}' + + is_plotly = self._engine == PlotterEngineEnum.PLOTLY.value and isinstance( + self._backend, PlotlyPlotter + ) + display_corr_df, row_numbers, col_numbers = self._trim_correlation_display_dataframe( + corr_df, + triangle=triangle, + show_diagonal=show_diagonal, + preserve_all_rows=not is_plotly, + ) + + if is_plotly: + self._plot_correlation_heatmap(display_corr_df, title) + return + + console.paragraph(title) + TableRenderer.get().render( + self._format_correlation_table_dataframe( + display_corr_df, + row_numbers=row_numbers, + col_numbers=col_numbers, + threshold=threshold, + precision=precision, + ) + ) + + @staticmethod + def _filter_correlation_dataframe( + corr_df: pd.DataFrame, + threshold: float | None, + ) -> pd.DataFrame | None: + """ + Filter a correlation matrix to only strongly correlated params. + + Parameters + ---------- + corr_df : pd.DataFrame + Square correlation matrix. + threshold : float | None + Absolute-correlation cutoff. ``None`` or ``0`` keeps all + parameters. + + Returns + ------- + pd.DataFrame | None + Filtered square matrix, or ``None`` if no off-diagonal + correlations meet the cutoff. + + Raises + ------ + ValueError + If *threshold* is outside ``[0, 1]``. + """ + if threshold is None or threshold <= 0: + return corr_df + if threshold > 1: + msg = 'Correlation threshold must be between 0 and 1.' + raise ValueError(msg) + + abs_corr = np.abs(corr_df.to_numpy(copy=True)) + np.fill_diagonal(abs_corr, 0.0) + keep_mask = (abs_corr >= threshold).any(axis=0) + + if not keep_mask.any(): + log.warning(f'No parameter pairs with |correlation| >= {threshold:.2f} were found.') + return None + + labels = corr_df.index[keep_mask] + return corr_df.loc[labels, labels] + + @staticmethod + def _mask_correlation_triangle( + corr_df: pd.DataFrame, + triangle: str, + show_diagonal: bool, + ) -> pd.DataFrame: + """ + Mask the unused half of the symmetric correlation matrix. + + Parameters + ---------- + corr_df : pd.DataFrame + Square correlation matrix. + triangle : str + Which part of the matrix to keep: ``'lower'``, ``'upper'``, + or ``'full'``. + show_diagonal : bool + Whether to keep the diagonal values visible. + + Returns + ------- + pd.DataFrame + Correlation matrix with unused cells masked. + + Raises + ------ + ValueError + If *triangle* is unsupported. + """ + if triangle not in {'lower', 'upper', 'full'}: + msg = "Correlation triangle must be 'lower', 'upper', or 'full'." + raise ValueError(msg) + + masked_values = corr_df.to_numpy(copy=True) + k = 1 if show_diagonal else 0 + + if triangle == 'lower': + mask = np.triu(np.ones_like(masked_values, dtype=bool), k=k) + masked_values[mask] = np.nan + elif triangle == 'upper': + mask = np.tril(np.ones_like(masked_values, dtype=bool), k=-k) + masked_values[mask] = np.nan + elif not show_diagonal: + diag_idx = np.diag_indices_from(masked_values) + masked_values[diag_idx] = np.nan + + return pd.DataFrame(masked_values, index=corr_df.index, columns=corr_df.columns) + + @staticmethod + def _trim_correlation_display_dataframe( + corr_df: pd.DataFrame, + triangle: str, + show_diagonal: bool, + preserve_all_rows: bool, + ) -> tuple[pd.DataFrame, list[int], list[int]]: + """ + Trim empty outer rows/columns from triangle views. + + Parameters + ---------- + corr_df : pd.DataFrame + Masked correlation matrix. + triangle : str + Which triangle is shown. + show_diagonal : bool + Whether diagonal values are visible. + preserve_all_rows : bool + Whether to keep the full row list so row labels continue to + identify all numeric column headers in tabular output. + + Returns + ------- + tuple[pd.DataFrame, list[int], list[int]] + Display matrix plus 1-based parameter numbers for the kept + rows and columns. + """ + num_rows, num_cols = corr_df.shape + row_numbers = list(range(1, num_rows + 1)) + col_numbers = list(range(1, num_cols + 1)) + + if show_diagonal or triangle == 'full' or min(num_rows, num_cols) <= 1: + return corr_df, row_numbers, col_numbers + + if triangle == 'lower': + if preserve_all_rows: + return corr_df.iloc[:, :-1], row_numbers, col_numbers[:-1] + return corr_df.iloc[1:, :-1], row_numbers[1:], col_numbers[:-1] + if triangle == 'upper': + if preserve_all_rows: + return corr_df.iloc[:, 1:], row_numbers, col_numbers[1:] + return corr_df.iloc[:-1, 1:], row_numbers[:-1], col_numbers[1:] + + return corr_df, row_numbers, col_numbers + + def _get_param_correlation_dataframe(self) -> pd.DataFrame | None: + """ + Return the correlation matrix for the latest fit. + + Returns + ------- + pd.DataFrame | None + Square correlation matrix labeled by parameter unique names, + or ``None`` if unavailable. + """ + result = self._get_fit_result_for_correlation() + if result is None: + return None + raw_result, var_names, fit_results = result + + covar = getattr(raw_result, 'covar', None) + if covar is not None: + return self._correlation_from_covariance(covar, var_names, fit_results.parameters) + + corr_df = self._get_param_correlation_dataframe_from_engine_params( + raw_result=raw_result, + parameters=fit_results.parameters, + ) + if corr_df is not None: + return corr_df + + log.warning( + 'Correlation matrix is unavailable for this fit. ' + 'Use the lmfit minimizer and ensure covariance estimation succeeds.' + ) + return None + + def _get_fit_result_for_correlation( + self, + ) -> tuple[object, list[str], object] | None: + """ + Validate and return the raw fit result for correlation. + + Returns + ------- + tuple[object, list[str], object] | None + A tuple of ``(raw_result, var_names, fit_results)`` when all + required data is present, or ``None`` otherwise. + """ + if self._project is None: + log.warning('Plotter is not attached to a project.') + return None + + fit_results = getattr(self._project.analysis, 'fit_results', None) + if fit_results is None: + log.warning('No fit results available. Run fit() first.') + return None + + raw_result = getattr(fit_results, 'engine_result', None) + if raw_result is None: + log.warning('No raw fit result available. Correlation matrix cannot be plotted.') + return None + + var_names = getattr(raw_result, 'var_names', None) + if not var_names: + log.warning('Fit result does not expose variable names for a correlation matrix.') + return None + + return raw_result, var_names, fit_results + + @staticmethod + def _correlation_from_covariance( + covar: object, + var_names: list[str], + parameters: list[object], + ) -> pd.DataFrame | None: + """ + Convert a covariance matrix to a correlation DataFrame. + + Parameters + ---------- + covar : object + Raw covariance matrix from the fit result. + var_names : list[str] + Minimizer variable names. + parameters : list[object] + Fitted parameter descriptors. + + Returns + ------- + pd.DataFrame | None + Correlation matrix, or ``None`` if the covariance is + invalid. + """ + covar_array = np.asarray(covar, dtype=float) + if covar_array.ndim != EXPECTED_COVAR_NDIM or covar_array.shape[0] != covar_array.shape[1]: + log.warning('Fit result returned an invalid covariance matrix.') + return None + if covar_array.shape[0] != len(var_names): + log.warning('Covariance matrix size does not match the fitted parameter list.') + return None + + sigma = np.sqrt(np.diag(covar_array)) + with np.errstate(divide='ignore', invalid='ignore'): + corr = covar_array / np.outer(sigma, sigma) + corr = np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0) + np.fill_diagonal(corr, 1.0) + + labels = Plotter._get_correlation_labels(parameters, var_names) + return pd.DataFrame(corr, index=labels, columns=labels) + + @staticmethod + def _get_correlation_labels( + parameters: list[object], + var_names: list[str], + ) -> list[str]: + """ + Map minimizer variable names to readable parameter labels. + + Parameters + ---------- + parameters : list[object] + Fitted parameter descriptors. + var_names : list[str] + Minimizer variable names from the engine result. + + Returns + ------- + list[str] + Labels for the correlation matrix axes. + """ + labels_by_uid = { + getattr(param, '_minimizer_uid', ''): getattr( + param, 'unique_name', getattr(param, 'name', '') + ) + for param in parameters + } + return [labels_by_uid.get(name, name) for name in var_names] + + def _get_param_correlation_dataframe_from_engine_params( + self, + raw_result: object, + parameters: list[object], + ) -> pd.DataFrame | None: + """ + Reconstruct a correlation matrix from engine parameter metadata. + + This is a fallback for backends that populate per-parameter + correlation coefficients but do not expose a covariance matrix. + + Parameters + ---------- + raw_result : object + Backend-specific fit result. + parameters : list[object] + Fitted parameter descriptors. + + Returns + ------- + pd.DataFrame | None + Correlation matrix labeled by readable parameter names, or + ``None`` if no correlation coefficients are available. + """ + engine_params = getattr(raw_result, 'params', None) + var_names = getattr(raw_result, 'var_names', None) + if engine_params is None or not var_names: + return None + + corr = np.eye(len(var_names), dtype=float) + indices = {name: idx for idx, name in enumerate(var_names)} + found_corr = False + + for name, idx in indices.items(): + engine_param = engine_params.get(name) + param_corr = getattr(engine_param, 'correl', None) + if not param_corr: + continue + + for other_name, value in param_corr.items(): + other_idx = indices.get(other_name) + if other_idx is None: + continue + corr_value = float(value) + corr[idx, other_idx] = corr_value + corr[other_idx, idx] = corr_value + found_corr = True + + if not found_corr: + return None + + labels = self._get_correlation_labels(parameters, var_names) + return pd.DataFrame(corr, index=labels, columns=labels) + + def _plot_correlation_heatmap( + self, + corr_df: pd.DataFrame, + title: str, + ) -> None: + """ + Delegate correlation heatmap rendering to the Plotly backend. + + Parameters + ---------- + corr_df : pd.DataFrame + Square correlation matrix. + title : str + Figure title. + """ + self._backend.plot_correlation_heatmap(corr_df, title) + + @staticmethod + def _format_correlation_table_dataframe( + corr_df: pd.DataFrame, + row_numbers: list[int], + col_numbers: list[int], + threshold: float | None, + precision: int, + ) -> pd.DataFrame: + """ + Format a correlation matrix for TableRenderer. + + Parameters + ---------- + corr_df : pd.DataFrame + Correlation matrix labeled by parameter name. + row_numbers : list[int] + 1-based parameter numbers for displayed rows. + col_numbers : list[int] + 1-based parameter numbers for displayed columns. + threshold : float | None + Absolute-correlation cutoff used to blank low-magnitude + cells in the rendered table. ``None`` or ``0`` keeps all + non-masked values. + precision : int + Number of decimals to show in the rendered values. + + Returns + ------- + pd.DataFrame + DataFrame with MultiIndex columns and default numeric index, + suitable for :class:`TableRenderer`. Correlation columns use + 1-based numeric headers so they line up with the numbered + parameter rows in terminal output. + """ + rounded = corr_df.round(precision) + cell_width = max( + len(str(max(col_numbers, default=0))), + len(f'{-1.0:.{precision}f}'), + ) + headers = [('parameter', 'left')] + headers.extend((str(index).rjust(cell_width), 'right') for index in col_numbers) + + rows = [] + for label, values in rounded.iterrows(): + row_values = [] + for value in values.tolist(): + should_blank = pd.isna(value) or ( + threshold is not None and threshold > 0 and abs(float(value)) < threshold + ) + if should_blank: + row_values.append('') + else: + row_values.append(f'{float(value):>{cell_width}.{precision}f}') + rows.append([label, *row_values]) + + df = pd.DataFrame(rows, columns=pd.MultiIndex.from_tuples(headers)) + df.index = pd.Index([row_number - 1 for row_number in row_numbers]) + return df + def _plot_meas_data( self, pattern: object, diff --git a/tests/unit/easydiffraction/display/plotters/test_plotly.py b/tests/unit/easydiffraction/display/plotters/test_plotly.py index fbfd6d4ed..5335a6871 100644 --- a/tests/unit/easydiffraction/display/plotters/test_plotly.py +++ b/tests/unit/easydiffraction/display/plotters/test_plotly.py @@ -10,6 +10,42 @@ def test_module_import(): assert expected_module_name == actual_module_name +def test_default_template_name_prefers_jupyter_theme(monkeypatch): + import easydiffraction.display.plotters.plotly as pp + + monkeypatch.setattr(pp, 'in_jupyter', lambda: True) + monkeypatch.setattr(pp, 'is_dark', lambda: True) + monkeypatch.setattr(pp.darkdetect, 'isDark', lambda: False) + + assert pp.PlotlyPlotter._default_template_name() == 'plotly_dark' + + +def test_correlation_colorscale_uses_black_center_in_dark_mode(monkeypatch): + import easydiffraction.display.plotters.plotly as pp + + monkeypatch.setattr(pp.PlotlyPlotter, '_is_dark_mode', staticmethod(lambda: True)) + + assert pp.PlotlyPlotter._correlation_colorscale()[1] == (0.5, '#000000') + + +def test_default_template_name_uses_system_theme_outside_jupyter(monkeypatch): + import easydiffraction.display.plotters.plotly as pp + + monkeypatch.setattr(pp, 'in_jupyter', lambda: False) + monkeypatch.setattr(pp, 'is_dark', lambda: False) + monkeypatch.setattr(pp.darkdetect, 'isDark', lambda: False) + + assert pp.PlotlyPlotter._default_template_name() == 'plotly_white' + + +def test_correlation_colorscale_uses_white_center_in_light_mode(monkeypatch): + import easydiffraction.display.plotters.plotly as pp + + monkeypatch.setattr(pp.PlotlyPlotter, '_is_dark_mode', staticmethod(lambda: False)) + + assert pp.PlotlyPlotter._correlation_colorscale()[1] == (0.5, '#f7f7f7') + + def test_get_trace_and_plot(monkeypatch): import easydiffraction.display.plotters.plotly as pp diff --git a/tests/unit/easydiffraction/display/test_plotting.py b/tests/unit/easydiffraction/display/test_plotting.py index 840a61dff..3f5b79d68 100644 --- a/tests/unit/easydiffraction/display/test_plotting.py +++ b/tests/unit/easydiffraction/display/test_plotting.py @@ -166,3 +166,376 @@ def __init__(self): p._plot_meas_data(Ptn(), 'E', ExptType()) assert called['labels'] == ('meas',) assert 'Measured data' in called['title'] + + +def test_plot_param_correlations_renders_ascii_table(monkeypatch): + import numpy as np + + from easydiffraction.display.plotting import Plotter + from easydiffraction.display.tables import TableRenderer + + captured = {} + + class FakeTabler: + def render(self, df): + captured['df'] = df + + monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = np.array([[4.0, 1.0], [1.0, 9.0]]) + var_names = ['p1', 'p2'] + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'phase.cell.length_a'), + Param('p2', 'phase.cell.length_b'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'asciichartpy' + p._set_project(Project()) + p.plot_param_correlations(threshold=0.1, precision=3) + + df = captured['df'] + assert [column.strip() for column in df.columns.get_level_values(0)] == [ + 'parameter', + '1', + ] + assert list(df.columns.get_level_values(1)) == ['left', 'right'] + assert list(df.index) == [0, 1] + assert df.iloc[0, 0] == 'phase.cell.length_a' + assert df.iloc[0, 1] == '' + assert df.iloc[1, 0] == 'phase.cell.length_b' + assert df.iloc[1, 1].strip() == '0.167' + + +def test_plot_param_correlations_renders_plotly_heatmap(monkeypatch): + import numpy as np + + import easydiffraction.display.plotters.plotly as plotly_mod + from easydiffraction.display.plotting import Plotter + + captured = {} + + def fake_show_figure(self, fig): + captured['fig'] = fig + + monkeypatch.setattr(plotly_mod.PlotlyPlotter, '_show_figure', fake_show_figure) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = np.array([[1.0, -0.5], [-0.5, 1.0]]) + var_names = ['p1', 'p2'] + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'phase.scale'), + Param('p2', 'phase.cell.length_c'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'plotly' + p._set_project(Project()) + p.plot_param_correlations(threshold=0.1) + + fig = captured['fig'] + assert fig.data[0].type == 'heatmap' + assert list(fig.data[0].x) == [0.0, 1.0] + assert list(fig.data[0].y) == [0.0, 1.0] + assert fig.data[0].xgap in (None, 0) + assert fig.data[0].ygap in (None, 0) + assert fig.data[0].colorbar.lenmode == 'fraction' + assert fig.data[0].colorbar.len == 1.0 + assert fig.data[0].colorbar.title.text == '' + assert pytest.approx(fig.data[0].z[0][0], rel=1e-9) == -0.5 + assert fig.layout.xaxis.side == 'bottom' + assert fig.layout.xaxis.tickangle == -45 + assert list(fig.layout.xaxis.tickvals) == [0.5] + assert list(fig.layout.xaxis.ticktext) == ['phase.scale'] + assert fig.layout.xaxis.showline is False + assert fig.layout.xaxis.mirror is False + assert fig.layout.xaxis.layer == 'above traces' + assert list(fig.layout.yaxis.tickvals) == [0.5] + assert list(fig.layout.yaxis.ticktext) == ['phase.cell.length_c'] + assert fig.layout.yaxis.showline is False + assert fig.layout.yaxis.mirror is False + assert fig.layout.yaxis.layer == 'above traces' + assert fig.layout.yaxis.ticklabelstandoff == 8 + assert len(fig.layout.shapes) == 1 + assert fig.layout.shapes[-1].type == 'rect' + assert fig.layout.shapes[-1].xref == 'paper' + assert fig.layout.shapes[-1].yref == 'paper' + + +def test_plot_param_correlations_can_show_diagonal(monkeypatch): + import numpy as np + + from easydiffraction.display.plotting import Plotter + from easydiffraction.display.tables import TableRenderer + + captured = {} + + class FakeTabler: + def render(self, df): + captured['df'] = df + + monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = np.array([[1.0, 0.4], [0.4, 1.0]]) + var_names = ['p1', 'p2'] + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'phase.scale'), + Param('p2', 'phase.cell.length_c'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'asciichartpy' + p._set_project(Project()) + p.plot_param_correlations(threshold=0.1, show_diagonal=True) + + df = captured['df'] + assert df.iloc[0, 1].strip() == '1.00' + assert df.iloc[0, 2] == '' + assert df.iloc[1, 2].strip() == '1.00' + + +def test_plot_param_correlations_can_show_full_matrix(monkeypatch): + import numpy as np + + from easydiffraction.display.plotting import Plotter + from easydiffraction.display.tables import TableRenderer + + captured = {} + + class FakeTabler: + def render(self, df): + captured['df'] = df + + monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = np.array([[1.0, 0.4], [0.4, 1.0]]) + var_names = ['p1', 'p2'] + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'phase.scale'), + Param('p2', 'phase.cell.length_c'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'asciichartpy' + p._set_project(Project()) + p.plot_param_correlations(threshold=0.1, triangle='full', show_diagonal=False) + + df = captured['df'] + assert df.iloc[0, 1] == '' + assert df.iloc[0, 2].strip() == '0.40' + assert df.iloc[1, 1].strip() == '0.40' + assert df.iloc[1, 2] == '' + + +def test_plot_param_correlations_filters_by_default_threshold(monkeypatch): + from easydiffraction.display.plotting import Plotter + from easydiffraction.display.tables import TableRenderer + + captured = {} + + class FakeTabler: + def render(self, df): + captured['df'] = df + + monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = None + var_names = ['p1', 'p2', 'p3'] + + class ParamResult: + def __init__(self, correl): + self.correl = correl + + params = { + 'p1': ParamResult({'p2': 0.82}), + 'p2': ParamResult({'p1': 0.82}), + 'p3': ParamResult({'p1': 0.25}), + } + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'phase.scale'), + Param('p2', 'phase.cell.length_a'), + Param('p3', 'phase.background'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'asciichartpy' + p._set_project(Project()) + p.plot_param_correlations() + + df = captured['df'] + assert [column.strip() for column in df.columns.get_level_values(0)] == [ + 'parameter', + '1', + ] + assert list(df.index) == [0, 1] + assert df.iloc[0, 0] == 'phase.scale' + assert df.iloc[0, 1] == '' + assert df.iloc[1, 0] == 'phase.cell.length_a' + assert df.iloc[1, 1].strip() == '0.82' + + +def test_plot_param_correlations_hides_subthreshold_table_values(monkeypatch): + from easydiffraction.display.plotting import Plotter + from easydiffraction.display.tables import TableRenderer + + captured = {} + + class FakeTabler: + def render(self, df): + captured['df'] = df + + monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = None + var_names = ['p1', 'p2', 'p3', 'p4', 'p5'] + + class ParamResult: + def __init__(self, correl): + self.correl = correl + + params = { + 'p1': ParamResult({'p4': 0.02, 'p5': 0.82}), + 'p2': ParamResult({'p3': -0.91, 'p5': 0.02}), + 'p3': ParamResult({'p2': -0.91, 'p4': -0.89, 'p5': -0.01}), + 'p4': ParamResult({'p1': 0.02, 'p3': -0.89, 'p5': 0.01}), + 'p5': ParamResult({'p1': 0.82, 'p2': 0.02, 'p3': -0.01, 'p4': 0.01}), + } + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'lbco.cell.length_a'), + Param('p2', 'hrpt.peak.broad_gauss_u'), + Param('p3', 'hrpt.peak.broad_gauss_v'), + Param('p4', 'hrpt.peak.broad_gauss_w'), + Param('p5', 'hrpt.instrument.twotheta_offset'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'asciichartpy' + p._set_project(Project()) + p.plot_param_correlations() + + df = captured['df'] + assert [column.strip() for column in df.columns.get_level_values(0)] == [ + 'parameter', + '1', + '2', + '3', + '4', + ] + assert list(df.index) == [0, 1, 2, 3, 4] + assert df.iloc[0, 1] == '' + assert df.iloc[1, 1] == '' + assert df.iloc[2, 1] == '' + assert df.iloc[2, 2].strip() == '-0.91' + assert df.iloc[3, 1] == '' + assert df.iloc[3, 2] == '' + assert df.iloc[3, 3].strip() == '-0.89' + assert df.iloc[4, 1].strip() == '0.82' + assert df.iloc[4, 2] == '' + assert df.iloc[4, 3] == '' + assert df.iloc[4, 4] == '' + + +def test_plot_param_correlations_threshold_validation(): + from easydiffraction.display.plotting import Plotter + + with pytest.raises(ValueError, match='between 0 and 1'): + Plotter._filter_correlation_dataframe(object(), threshold=1.1) + + +def test_plot_param_correlations_triangle_validation(): + import pandas as pd + + from easydiffraction.display.plotting import Plotter + + df = pd.DataFrame([[1.0]]) + with pytest.raises(ValueError, match='lower'): + Plotter._mask_correlation_triangle(df, triangle='sideways', show_diagonal=False) diff --git a/tmp/correlations/ed-1.py b/tmp/correlations/ed-1.py index e654343d5..100f0ac0f 100644 --- a/tmp/correlations/ed-1.py +++ b/tmp/correlations/ed-1.py @@ -55,12 +55,32 @@ project.experiments.add_from_cif_path(expt_path) # %% [markdown] -# ## Step 4: Perform Analysis (cryspy) +# ## Step 4: Perform Analysis (no constraints) # %% -# Define aliases and constraints for refinement. This is necessary to -# properly refine the isotropic displacement parameters of La and Ba, -# which are correlated due to their shared Wyckoff position. +# Start refinement. All parameters, which have standard uncertainties +# in the input CIF files, are refined by default. +project.analysis.fit() + +# %% +# Show fit results summary +project.analysis.display.fit_results() + +# %% +# Show parameter correlations +project.plotter.plot_param_correlations() + +# %% [markdown] +# ## Step 5: Perform Analysis (with constraints) + +# %% +# As can be seen from the parameter-correlation plot, the isotropic +# displacement parameters of La and Ba are highly correlated. Because +# La and Ba share the same mixed-occupancy site, their contributions to +# the neutron diffraction pattern are difficult to separate, especially +# since their coherent scattering lengths are not very different. +# Therefore, it is necessary to constrain them to be equal. First we +# define aliases and then use them to create a constraint. project.analysis.aliases.create( label='biso_La', param=project.structures['lbco'].atom_sites['La'].b_iso, @@ -80,6 +100,10 @@ # Show fit results summary project.analysis.display.fit_results() +# %% +# Show parameter correlations +project.plotter.plot_param_correlations() + # %% # Show defined experiment names project.experiments.show_names() From 7e99d1f942f36002688d5ef63a395ad61f23338d Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 08:56:12 +0200 Subject: [PATCH 03/10] Enhance correlation heatmap with threshold and precision parameters for labels --- .../display/plotters/plotly.py | 95 ++++++++++++++++++- src/easydiffraction/display/plotting.py | 20 +++- .../easydiffraction/display/test_plotting.py | 72 +++++++++++++- 3 files changed, 181 insertions(+), 6 deletions(-) diff --git a/src/easydiffraction/display/plotters/plotly.py b/src/easydiffraction/display/plotters/plotly.py index c772754a9..df9d263c9 100644 --- a/src/easydiffraction/display/plotters/plotly.py +++ b/src/easydiffraction/display/plotters/plotly.py @@ -115,6 +115,8 @@ def plot_correlation_heatmap( self, corr_df: object, title: str, + threshold: float | None, + precision: int, ) -> None: """ Render a Plotly heatmap for a correlation matrix. @@ -125,6 +127,10 @@ def plot_correlation_heatmap( Square correlation DataFrame. title : str Figure title. + threshold : float | None + Absolute-correlation cutoff used for value labels. + precision : int + Number of decimals to show in labels and hover text. """ num_rows, num_cols = corr_df.shape x_edges = np.arange(num_cols + 1, dtype=float) @@ -149,7 +155,15 @@ def plot_correlation_heatmap( 'yanchor': 'middle', }, hoverongaps=False, - hovertemplate='x: %{x}
y: %{y}
corr: %{z:.3f}', + hovertemplate=f'x: %{{x}}
y: %{{y}}
correlation: %{{z:' + f'.{precision}f}}', + ) + label_trace = self._get_correlation_label_trace( + corr_df, + x_centers=x_centers, + y_centers=y_centers, + threshold=threshold, + precision=precision, ) shapes = [ @@ -198,10 +212,13 @@ def plot_correlation_heatmap( ['Parameter', 'Parameter'], shapes=shapes, ) - fig = self._get_figure([heatmap], layout) + traces = [heatmap] + if label_trace is not None: + traces.append(label_trace) + fig = self._get_figure(traces, layout) fig.update_xaxes( side='bottom', - tickangle=-45, + tickangle=-10, automargin=True, tickmode='array', tickvals=x_centers.tolist(), @@ -229,6 +246,78 @@ def plot_correlation_heatmap( ) self._show_figure(fig) + @classmethod + def _correlation_label_color(cls) -> str: + """ + Return the text color used for in-cell correlation labels. + + Returns + ------- + str + Hex color string. + """ + return '#f5f5f5' + + @classmethod + def _get_correlation_label_trace( + cls, + corr_df: object, + x_centers: np.ndarray, + y_centers: np.ndarray, + threshold: float | None, + precision: int, + ) -> object | None: + """ + Build a text trace for visible correlation values. + + Parameters + ---------- + corr_df : object + Correlation DataFrame to annotate. + x_centers : np.ndarray + Cell center x coordinates. + y_centers : np.ndarray + Cell center y coordinates. + threshold : float | None + Minimum absolute correlation required for a label. + precision : int + Number of decimals for rendered labels. + + Returns + ------- + object | None + Plotly text trace, or ``None`` when no labels should be + shown. + """ + values = corr_df.to_numpy() + label_x = [] + label_y = [] + label_text = [] + + for row_idx, row in enumerate(values): + for col_idx, value in enumerate(row): + if np.isnan(value): + continue + if threshold is not None and threshold > 0 and abs(float(value)) < threshold: + continue + label_x.append(float(x_centers[col_idx])) + label_y.append(float(y_centers[row_idx])) + label_text.append(f'{float(value):.{precision}f}') + + if not label_text: + return None + + return go.Scatter( + x=label_x, + y=label_y, + mode='text', + text=label_text, + textposition='middle center', + textfont={'color': cls._correlation_label_color()}, + hoverinfo='skip', + showlegend=False, + ) + @staticmethod def _get_powder_trace( x: object, diff --git a/src/easydiffraction/display/plotting.py b/src/easydiffraction/display/plotting.py index 40add3da3..49d6b4ce2 100644 --- a/src/easydiffraction/display/plotting.py +++ b/src/easydiffraction/display/plotting.py @@ -562,7 +562,12 @@ def plot_param_correlations( ) if is_plotly: - self._plot_correlation_heatmap(display_corr_df, title) + self._plot_correlation_heatmap( + display_corr_df, + title, + threshold=threshold, + precision=precision, + ) return console.paragraph(title) @@ -906,6 +911,8 @@ def _plot_correlation_heatmap( self, corr_df: pd.DataFrame, title: str, + threshold: float | None, + precision: int, ) -> None: """ Delegate correlation heatmap rendering to the Plotly backend. @@ -916,8 +923,17 @@ def _plot_correlation_heatmap( Square correlation matrix. title : str Figure title. + threshold : float | None + Absolute-correlation cutoff used for value labels. + precision : int + Number of decimals to show in plot labels and hover text. """ - self._backend.plot_correlation_heatmap(corr_df, title) + self._backend.plot_correlation_heatmap( + corr_df, + title, + threshold=threshold, + precision=precision, + ) @staticmethod def _format_correlation_table_dataframe( diff --git a/tests/unit/easydiffraction/display/test_plotting.py b/tests/unit/easydiffraction/display/test_plotting.py index 3f5b79d68..66b90f90d 100644 --- a/tests/unit/easydiffraction/display/test_plotting.py +++ b/tests/unit/easydiffraction/display/test_plotting.py @@ -263,6 +263,7 @@ class Project: p.plot_param_correlations(threshold=0.1) fig = captured['fig'] + assert len(fig.data) == 2 assert fig.data[0].type == 'heatmap' assert list(fig.data[0].x) == [0.0, 1.0] assert list(fig.data[0].y) == [0.0, 1.0] @@ -271,9 +272,17 @@ class Project: assert fig.data[0].colorbar.lenmode == 'fraction' assert fig.data[0].colorbar.len == 1.0 assert fig.data[0].colorbar.title.text == '' + assert fig.data[0].hovertemplate == 'x: %{x}
y: %{y}
corr: %{z:.2f}' assert pytest.approx(fig.data[0].z[0][0], rel=1e-9) == -0.5 + assert fig.data[1].type == 'scatter' + assert fig.data[1].mode == 'text' + assert list(fig.data[1].x) == [0.5] + assert list(fig.data[1].y) == [0.5] + assert list(fig.data[1].text) == ['-0.50'] + assert fig.data[1].textposition == 'middle center' + assert fig.data[1].hoverinfo == 'skip' assert fig.layout.xaxis.side == 'bottom' - assert fig.layout.xaxis.tickangle == -45 + assert fig.layout.xaxis.tickangle < 0 assert list(fig.layout.xaxis.tickvals) == [0.5] assert list(fig.layout.xaxis.ticktext) == ['phase.scale'] assert fig.layout.xaxis.showline is False @@ -291,6 +300,67 @@ class Project: assert fig.layout.shapes[-1].yref == 'paper' +def test_plot_param_correlations_plotly_labels_respect_threshold(monkeypatch): + import easydiffraction.display.plotters.plotly as plotly_mod + from easydiffraction.display.plotting import Plotter + + captured = {} + + def fake_show_figure(self, fig): + captured['fig'] = fig + + monkeypatch.setattr(plotly_mod.PlotlyPlotter, '_show_figure', fake_show_figure) + + class Param: + def __init__(self, uid, unique_name): + self._minimizer_uid = uid + self.unique_name = unique_name + + class RawResult: + covar = None + var_names = ['p1', 'p2', 'p3', 'p4', 'p5'] + + class ParamResult: + def __init__(self, correl): + self.correl = correl + + params = { + 'p1': ParamResult({'p4': 0.02, 'p5': 0.82}), + 'p2': ParamResult({'p3': -0.91, 'p4': 0.83, 'p5': 0.02}), + 'p3': ParamResult({'p2': -0.91, 'p4': -0.89, 'p5': -0.01}), + 'p4': ParamResult({'p1': 0.02, 'p2': 0.83, 'p3': -0.89, 'p5': 0.01}), + 'p5': ParamResult({'p1': 0.82, 'p2': 0.02, 'p3': -0.01, 'p4': 0.01}), + } + + class FitResults: + engine_result = RawResult() + parameters = [ + Param('p1', 'lbco.cell.length_a'), + Param('p2', 'hrpt.peak.broad_gauss_u'), + Param('p3', 'hrpt.peak.broad_gauss_v'), + Param('p4', 'hrpt.peak.broad_gauss_w'), + Param('p5', 'hrpt.instrument.twotheta_offset'), + ] + + class Analysis: + fit_results = FitResults() + + class Project: + analysis = Analysis() + + p = Plotter() + p.engine = 'plotly' + p._set_project(Project()) + p.plot_param_correlations() + + fig = captured['fig'] + assert len(fig.data) == 2 + assert fig.data[0].type == 'heatmap' + assert fig.data[1].type == 'scatter' + assert fig.data[1].mode == 'text' + assert list(fig.data[1].text) == ['-0.91', '0.83', '-0.89', '0.82'] + + def test_plot_param_correlations_can_show_diagonal(monkeypatch): import numpy as np From 7e3ddd0c8d220a1e2184034ae09963f9266ec303 Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 08:56:34 +0200 Subject: [PATCH 04/10] Refine documentation and examples --- docs/docs/tutorials/ed-1.py | 73 ++++++++++++++++++------------------ docs/docs/tutorials/ed-18.py | 7 +++- docs/docs/tutorials/ed-2.py | 2 +- docs/docs/tutorials/index.md | 12 +++--- docs/mkdocs.yml | 2 +- 5 files changed, 50 insertions(+), 46 deletions(-) diff --git a/docs/docs/tutorials/ed-1.py b/docs/docs/tutorials/ed-1.py index d74c7eea3..3fcd9aab7 100644 --- a/docs/docs/tutorials/ed-1.py +++ b/docs/docs/tutorials/ed-1.py @@ -1,23 +1,20 @@ # %% [markdown] # # Structure Refinement: LBCO, HRPT # -# This minimalistic example is designed to show how Rietveld refinement -# can be performed when both the crystal structure and experiment -# parameters are defined using CIF files. +# This basic example is designed to show how Rietveld refinement can be +# performed when both the crystal structure and experiment parameters +# are defined using CIF files. # # For this example, constant-wavelength neutron powder diffraction data # for La0.5Ba0.5CoO3 from HRPT at PSI is used. # -# It does not contain any advanced features or options, and includes no -# comments or explanations—these can be found in the other tutorials. -# Default values are used for all parameters if not specified. Only -# essential and self-explanatory code is provided. -# # The example is intended for users who are already familiar with the -# EasyDiffraction library and want to quickly get started with a simple -# refinement. It is also useful for those who want to see what a -# refinement might look like in code. For a more detailed explanation of -# the code, please refer to the other tutorials. +# EasyDiffraction library and want to quickly get started with a basic +# refinement. +# +# It is also useful for those who want to see how constraints can be +# applied to highly correlated parameters. For a more detailed +# explanation of the code, please refer to the other tutorials. # %% [markdown] # ## Import Library @@ -55,12 +52,32 @@ project.experiments.add_from_cif_path(expt_path) # %% [markdown] -# ## Step 4: Perform Analysis (cryspy) +# ## Step 4: Perform Analysis (no constraints) + +# %% +# Start refinement. All parameters, which have standard uncertainties +# in the input CIF files, are refined by default. +project.analysis.fit() + +# %% +# Show fit results summary +project.analysis.display.fit_results() + +# %% +# Show parameter correlations +project.plotter.plot_param_correlations() + +# %% [markdown] +# ## Step 5: Perform Analysis (with constraints) # %% -# Define aliases and constraints for refinement. This is necessary to -# properly refine the isotropic displacement parameters of La and Ba, -# which are correlated due to their shared Wyckoff position. +# As can be seen from the parameter-correlation plot, the isotropic +# displacement parameters of La and Ba are highly correlated. Because +# La and Ba share the same mixed-occupancy site, their contributions to +# the neutron diffraction pattern are difficult to separate, especially +# since their coherent scattering lengths are not very different. +# Therefore, it is necessary to constrain them to be equal. First we +# define aliases and then use them to create a constraint. project.analysis.aliases.create( label='biso_La', param=project.structures['lbco'].atom_sites['La'].b_iso, @@ -81,28 +98,12 @@ project.analysis.display.fit_results() # %% -# Show defined experiment names -project.experiments.show_names() - -# %% -# Plot measured vs. calculated diffraction patterns -project.plotter.plot_meas_vs_calc(expt_name='hrpt', show_residual=True) - -# %% [markdown] -# ## Step 5: Perform Analysis (crysfml) - -# %% -# Change calculation engine from 'cryspy' to 'crysfml' -project.experiments['hrpt'].show_supported_calculator_types() -project.experiments['hrpt'].calculator_type = 'crysfml' - -# %% -# Start refinement -project.analysis.fit() +# Show parameter correlations +project.plotter.plot_param_correlations() # %% -# Show fit results summary -project.analysis.display.fit_results() +# Show defined experiment names +project.experiments.show_names() # %% # Plot measured vs. calculated diffraction patterns diff --git a/docs/docs/tutorials/ed-18.py b/docs/docs/tutorials/ed-18.py index 5b2c4b673..fbfd39e58 100644 --- a/docs/docs/tutorials/ed-18.py +++ b/docs/docs/tutorials/ed-18.py @@ -5,8 +5,11 @@ # how to load a previously saved project from a directory and run # refinement — all in just a few lines of code. # -# For details on how to define structures and experiments, see the other -# tutorials. +# For this example, constant-wavelength neutron powder diffraction data +# for La0.5Ba0.5CoO3 from HRPT at PSI is used. +# +# It does not contain any advanced features or options, and includes no +# comments or explanations — these can be found in the other tutorials. # %% [markdown] # ## Import Modules diff --git a/docs/docs/tutorials/ed-2.py b/docs/docs/tutorials/ed-2.py index 4dd78389a..609f54db9 100644 --- a/docs/docs/tutorials/ed-2.py +++ b/docs/docs/tutorials/ed-2.py @@ -10,7 +10,7 @@ # for La0.5Ba0.5CoO3 from HRPT at PSI is used. # # It does not contain any advanced features or options, and includes no -# comments or explanations—these can be found in the other tutorials. +# comments or explanations — these can be found in the other tutorials. # Default values are used for all parameters if not specified. Only # essential and self-explanatory code is provided. # diff --git a/docs/docs/tutorials/index.md b/docs/docs/tutorials/index.md index 914bb8592..45beb04bc 100644 --- a/docs/docs/tutorials/index.md +++ b/docs/docs/tutorials/index.md @@ -21,18 +21,18 @@ The tutorials are organized into the following categories: how to load a previously saved project from a directory and run refinement. Useful when a project has already been set up and saved in a prior session. -- [LBCO `quick` CIF](ed-1.ipynb) – A minimal example intended as a quick - reference for users already familiar with the EasyDiffraction API or - who want to see how Rietveld refinement of the La0.5Ba0.5CoO3 crystal - structure can be performed when both the structure and experiment are - loaded from CIF files. Data collected from constant wavelength neutron - powder diffraction at HRPT at PSI. - [LBCO `quick` `code`](ed-2.ipynb) – A minimal example intended as a quick reference for users already familiar with the EasyDiffraction API or who want to see an example refinement when both the structure and experiment are defined directly in code. This tutorial covers a Rietveld refinement of the La0.5Ba0.5CoO3 crystal structure using constant wavelength neutron powder diffraction data from HRPT at PSI. +- [LBCO `basic` `load`](ed-1.ipynb) – A basic example intended as a + quick reference for users already familiar with the EasyDiffraction + API or who want to see how Rietveld refinement of the La0.5Ba0.5CoO3 + crystal structure can be performed when both the structure and + experiment are loaded from CIF files. Data collected from constant + wavelength neutron powder diffraction at HRPT at PSI. - [LBCO `complete`](ed-3.ipynb) – Demonstrates the use of the EasyDiffraction API in a simplified, user-friendly manner that closely follows the GUI workflow for a Rietveld refinement of the diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 662650291..10964636c 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -192,8 +192,8 @@ nav: - Tutorials: tutorials/index.md - Getting Started: - LBCO quick load: tutorials/ed-18.ipynb - - LBCO quick CIF: tutorials/ed-1.ipynb - LBCO quick code: tutorials/ed-2.ipynb + - LBCO basic load: tutorials/ed-1.ipynb - LBCO complete: tutorials/ed-3.ipynb - Powder Diffraction: - Co2SiO4 pd-neut-cwl: tutorials/ed-5.ipynb From 2615689764d9f40e6195a546631b1112b5e320d7 Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 08:56:42 +0200 Subject: [PATCH 05/10] Refactor example notebooks to improve clarity --- docs/docs/tutorials/ed-1.ipynb | 95 ++++++++++--------- docs/docs/tutorials/ed-10.ipynb | 2 +- docs/docs/tutorials/ed-11.ipynb | 2 +- docs/docs/tutorials/ed-12.ipynb | 2 +- docs/docs/tutorials/ed-13.ipynb | 4 +- docs/docs/tutorials/ed-14.ipynb | 2 +- docs/docs/tutorials/ed-15.ipynb | 2 +- docs/docs/tutorials/ed-16.ipynb | 2 +- docs/docs/tutorials/ed-17.ipynb | 2 +- docs/docs/tutorials/ed-18.ipynb | 9 +- docs/docs/tutorials/ed-2.ipynb | 4 +- docs/docs/tutorials/ed-3.ipynb | 2 +- docs/docs/tutorials/ed-4.ipynb | 2 +- docs/docs/tutorials/ed-5.ipynb | 2 +- docs/docs/tutorials/ed-6.ipynb | 2 +- docs/docs/tutorials/ed-7.ipynb | 2 +- docs/docs/tutorials/ed-8.ipynb | 2 +- docs/docs/tutorials/ed-9.ipynb | 2 +- .../display/plotters/plotly.py | 3 +- 19 files changed, 73 insertions(+), 70 deletions(-) diff --git a/docs/docs/tutorials/ed-1.ipynb b/docs/docs/tutorials/ed-1.ipynb index 740e2c020..604fe72da 100644 --- a/docs/docs/tutorials/ed-1.ipynb +++ b/docs/docs/tutorials/ed-1.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b1e6b328", + "id": "ab0fa7f5", "metadata": { "tags": [ "hide-in-docs" @@ -26,23 +26,20 @@ "source": [ "# Structure Refinement: LBCO, HRPT\n", "\n", - "This minimalistic example is designed to show how Rietveld refinement\n", - "can be performed when both the crystal structure and experiment\n", - "parameters are defined using CIF files.\n", + "This basic example is designed to show how Rietveld refinement can be\n", + "performed when both the crystal structure and experiment parameters\n", + "are defined using CIF files.\n", "\n", "For this example, constant-wavelength neutron powder diffraction data\n", "for La0.5Ba0.5CoO3 from HRPT at PSI is used.\n", "\n", - "It does not contain any advanced features or options, and includes no\n", - "comments or explanations—these can be found in the other tutorials.\n", - "Default values are used for all parameters if not specified. Only\n", - "essential and self-explanatory code is provided.\n", - "\n", "The example is intended for users who are already familiar with the\n", - "EasyDiffraction library and want to quickly get started with a simple\n", - "refinement. It is also useful for those who want to see what a\n", - "refinement might look like in code. For a more detailed explanation of\n", - "the code, please refer to the other tutorials." + "EasyDiffraction library and want to quickly get started with a basic\n", + "refinement.\n", + "\n", + "It is also useful for those who want to see how constraints can be\n", + "applied to highly correlated parameters. For a more detailed\n", + "explanation of the code, please refer to the other tutorials." ] }, { @@ -147,7 +144,7 @@ "id": "11", "metadata": {}, "source": [ - "## Step 4: Perform Analysis (cryspy)" + "## Step 4: Perform Analysis (no constraints)" ] }, { @@ -157,18 +154,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Define aliases and constraints for refinement. This is necessary to\n", - "# properly refine the isotropic displacement parameters of La and Ba,\n", - "# which are correlated due to their shared Wyckoff position.\n", - "project.analysis.aliases.create(\n", - " label='biso_La',\n", - " param=project.structures['lbco'].atom_sites['La'].b_iso,\n", - ")\n", - "project.analysis.aliases.create(\n", - " label='biso_Ba',\n", - " param=project.structures['lbco'].atom_sites['Ba'].b_iso,\n", - ")\n", - "project.analysis.constraints.create(expression='biso_Ba = biso_La')" + "# Start refinement. All parameters, which have standard uncertainties\n", + "# in the input CIF files, are refined by default.\n", + "project.analysis.fit()" ] }, { @@ -178,9 +166,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Start refinement. All parameters, which have standard uncertainties\n", - "# in the input CIF files, are refined by default.\n", - "project.analysis.fit()" + "# Show fit results summary\n", + "project.analysis.display.fit_results()" ] }, { @@ -190,19 +177,16 @@ "metadata": {}, "outputs": [], "source": [ - "# Show fit results summary\n", - "project.analysis.display.fit_results()" + "# Show parameter correlations\n", + "project.plotter.plot_param_correlations()" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "id": "15", "metadata": {}, - "outputs": [], "source": [ - "# Show defined experiment names\n", - "project.experiments.show_names()" + "## Step 5: Perform Analysis (with constraints)" ] }, { @@ -212,16 +196,34 @@ "metadata": {}, "outputs": [], "source": [ - "# Plot measured vs. calculated diffraction patterns\n", - "project.plotter.plot_meas_vs_calc(expt_name='hrpt', show_residual=True)" + "# As can be seen from the parameter-correlation plot, the isotropic\n", + "# displacement parameters of La and Ba are highly correlated. Because\n", + "# La and Ba share the same mixed-occupancy site, their contributions to\n", + "# the neutron diffraction pattern are difficult to separate, especially\n", + "# since their coherent scattering lengths are not very different.\n", + "# Therefore, it is necessary to constrain them to be equal. First we\n", + "# define aliases and then use them to create a constraint.\n", + "project.analysis.aliases.create(\n", + " label='biso_La',\n", + " param=project.structures['lbco'].atom_sites['La'].b_iso,\n", + ")\n", + "project.analysis.aliases.create(\n", + " label='biso_Ba',\n", + " param=project.structures['lbco'].atom_sites['Ba'].b_iso,\n", + ")\n", + "project.analysis.constraints.create(expression='biso_Ba = biso_La')" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "17", "metadata": {}, + "outputs": [], "source": [ - "## Step 5: Perform Analysis (crysfml)" + "# Start refinement. All parameters, which have standard uncertainties\n", + "# in the input CIF files, are refined by default.\n", + "project.analysis.fit()" ] }, { @@ -231,9 +233,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Change calculation engine from 'cryspy' to 'crysfml'\n", - "project.experiments['hrpt'].show_supported_calculator_types()\n", - "project.experiments['hrpt'].calculator_type = 'crysfml'" + "# Show fit results summary\n", + "project.analysis.display.fit_results()" ] }, { @@ -243,8 +244,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Start refinement\n", - "project.analysis.fit()" + "# Show parameter correlations\n", + "project.plotter.plot_param_correlations()" ] }, { @@ -254,8 +255,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Show fit results summary\n", - "project.analysis.display.fit_results()" + "# Show defined experiment names\n", + "project.experiments.show_names()" ] }, { diff --git a/docs/docs/tutorials/ed-10.ipynb b/docs/docs/tutorials/ed-10.ipynb index 46489bd63..b0cc149c4 100644 --- a/docs/docs/tutorials/ed-10.ipynb +++ b/docs/docs/tutorials/ed-10.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "307e00fa", + "id": "dc88a9ac", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-11.ipynb b/docs/docs/tutorials/ed-11.ipynb index c0bbdd682..7aba32c38 100644 --- a/docs/docs/tutorials/ed-11.ipynb +++ b/docs/docs/tutorials/ed-11.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab104e6", + "id": "9736040b", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-12.ipynb b/docs/docs/tutorials/ed-12.ipynb index d0913caee..75d6c515c 100644 --- a/docs/docs/tutorials/ed-12.ipynb +++ b/docs/docs/tutorials/ed-12.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f188a46a", + "id": "44960a4c", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-13.ipynb b/docs/docs/tutorials/ed-13.ipynb index 549d91678..254d49e00 100644 --- a/docs/docs/tutorials/ed-13.ipynb +++ b/docs/docs/tutorials/ed-13.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be0cfa00", + "id": "1ac207fe", "metadata": { "tags": [ "hide-in-docs" @@ -2647,7 +2647,7 @@ ], "metadata": { "jupytext": { - "cell_metadata_filter": "title,tags,-all", + "cell_metadata_filter": "tags,title,-all", "main_language": "python", "notebook_metadata_filter": "-all" } diff --git a/docs/docs/tutorials/ed-14.ipynb b/docs/docs/tutorials/ed-14.ipynb index 88c2077f3..9d2a9c5b0 100644 --- a/docs/docs/tutorials/ed-14.ipynb +++ b/docs/docs/tutorials/ed-14.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6f4b6a6", + "id": "86c9f966", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-15.ipynb b/docs/docs/tutorials/ed-15.ipynb index a6e4e7a03..12e88c10a 100644 --- a/docs/docs/tutorials/ed-15.ipynb +++ b/docs/docs/tutorials/ed-15.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ceaaff89", + "id": "c631fd19", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-16.ipynb b/docs/docs/tutorials/ed-16.ipynb index 2fa0dd21e..68dd7243f 100644 --- a/docs/docs/tutorials/ed-16.ipynb +++ b/docs/docs/tutorials/ed-16.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c342992c", + "id": "d0fcc613", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-17.ipynb b/docs/docs/tutorials/ed-17.ipynb index 3844add89..871d18890 100644 --- a/docs/docs/tutorials/ed-17.ipynb +++ b/docs/docs/tutorials/ed-17.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47b804b0", + "id": "0c20dcfb", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-18.ipynb b/docs/docs/tutorials/ed-18.ipynb index 082ab9efb..655a0b808 100644 --- a/docs/docs/tutorials/ed-18.ipynb +++ b/docs/docs/tutorials/ed-18.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15fcf6ce", + "id": "f9d269b5", "metadata": { "tags": [ "hide-in-docs" @@ -30,8 +30,11 @@ "how to load a previously saved project from a directory and run\n", "refinement — all in just a few lines of code.\n", "\n", - "For details on how to define structures and experiments, see the other\n", - "tutorials." + "For this example, constant-wavelength neutron powder diffraction data\n", + "for La0.5Ba0.5CoO3 from HRPT at PSI is used.\n", + "\n", + "It does not contain any advanced features or options, and includes no\n", + "comments or explanations — these can be found in the other tutorials." ] }, { diff --git a/docs/docs/tutorials/ed-2.ipynb b/docs/docs/tutorials/ed-2.ipynb index 0357de4ab..3b359bcb1 100644 --- a/docs/docs/tutorials/ed-2.ipynb +++ b/docs/docs/tutorials/ed-2.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15c2d211", + "id": "9e18bf0f", "metadata": { "tags": [ "hide-in-docs" @@ -35,7 +35,7 @@ "for La0.5Ba0.5CoO3 from HRPT at PSI is used.\n", "\n", "It does not contain any advanced features or options, and includes no\n", - "comments or explanations—these can be found in the other tutorials.\n", + "comments or explanations — these can be found in the other tutorials.\n", "Default values are used for all parameters if not specified. Only\n", "essential and self-explanatory code is provided.\n", "\n", diff --git a/docs/docs/tutorials/ed-3.ipynb b/docs/docs/tutorials/ed-3.ipynb index 8eade20f7..7b4099b1d 100644 --- a/docs/docs/tutorials/ed-3.ipynb +++ b/docs/docs/tutorials/ed-3.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7be48772", + "id": "3e463ed0", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-4.ipynb b/docs/docs/tutorials/ed-4.ipynb index e331438e2..82ecbcc2f 100644 --- a/docs/docs/tutorials/ed-4.ipynb +++ b/docs/docs/tutorials/ed-4.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "05436c40", + "id": "d969fc08", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-5.ipynb b/docs/docs/tutorials/ed-5.ipynb index 8c524eaf7..9f4f53310 100644 --- a/docs/docs/tutorials/ed-5.ipynb +++ b/docs/docs/tutorials/ed-5.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ce91035", + "id": "645a5335", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-6.ipynb b/docs/docs/tutorials/ed-6.ipynb index f1b7ad5db..be47740f1 100644 --- a/docs/docs/tutorials/ed-6.ipynb +++ b/docs/docs/tutorials/ed-6.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e53f4485", + "id": "fdc5cf61", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-7.ipynb b/docs/docs/tutorials/ed-7.ipynb index 57e2d7265..b80018e28 100644 --- a/docs/docs/tutorials/ed-7.ipynb +++ b/docs/docs/tutorials/ed-7.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "788753a7", + "id": "e195a39c", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-8.ipynb b/docs/docs/tutorials/ed-8.ipynb index 7b9cf3781..45a9f8381 100644 --- a/docs/docs/tutorials/ed-8.ipynb +++ b/docs/docs/tutorials/ed-8.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9e6c622d", + "id": "68e98133", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-9.ipynb b/docs/docs/tutorials/ed-9.ipynb index 2de9787f0..31579906e 100644 --- a/docs/docs/tutorials/ed-9.ipynb +++ b/docs/docs/tutorials/ed-9.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "336ce5f0", + "id": "f854b55b", "metadata": { "tags": [ "hide-in-docs" diff --git a/src/easydiffraction/display/plotters/plotly.py b/src/easydiffraction/display/plotters/plotly.py index df9d263c9..f6b16edbd 100644 --- a/src/easydiffraction/display/plotters/plotly.py +++ b/src/easydiffraction/display/plotters/plotly.py @@ -155,8 +155,7 @@ def plot_correlation_heatmap( 'yanchor': 'middle', }, hoverongaps=False, - hovertemplate=f'x: %{{x}}
y: %{{y}}
correlation: %{{z:' - f'.{precision}f}}', + hovertemplate=f'x: %{{x}}
y: %{{y}}
corr: %{{z:.{precision}f}}', ) label_trace = self._get_correlation_label_trace( corr_df, From fc34be653283e9f422c855ad4455fcc30b418912 Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 09:20:30 +0200 Subject: [PATCH 06/10] Temporarily lower minimum coverage percentage requirement from 75 to 70 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 58f6e668e..a2b9456df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ source = ['src'] # Limit coverage to the source code directory [tool.coverage.report] show_missing = true # Show missing lines skip_covered = false # Skip files with 100% coverage in the report -fail_under = 75 # Minimum coverage percentage to pass +fail_under = 70 # Minimum coverage percentage to pass ########################## # Configuration for pytest From 846ca15c599e92957126d6a29e4ae194630ab5cc Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 11:09:47 +0200 Subject: [PATCH 07/10] Remove show_diagonal and triangle from plot_param_correlations --- pixi.lock | 4 +- src/easydiffraction/display/plotters/base.py | 30 +++++ .../display/plotters/plotly.py | 2 + src/easydiffraction/display/plotting.py | 95 ++++------------ .../easydiffraction/display/test_plotting.py | 105 ------------------ 5 files changed, 58 insertions(+), 178 deletions(-) diff --git a/pixi.lock b/pixi.lock index 15c4b18ea..060895c9c 100644 --- a/pixi.lock +++ b/pixi.lock @@ -4511,8 +4511,8 @@ packages: requires_python: '>=3.5' - pypi: ./ name: easydiffraction - version: 0.11.1+devdirty43 - sha256: 2c841c32ccaac8a714fc637b889352caa4c4b22afb54b89f1305402a5bf8d574 + version: 0.11.1+dev15 + sha256: bdcda04c826721a0f6b38e12a82a543e37e84eb8d9365d20204bc9cc65cca981 requires_dist: - asciichartpy - asteval diff --git a/src/easydiffraction/display/plotters/base.py b/src/easydiffraction/display/plotters/base.py index d8ad2b485..2c0fb40c1 100644 --- a/src/easydiffraction/display/plotters/base.py +++ b/src/easydiffraction/display/plotters/base.py @@ -164,6 +164,8 @@ class PlotterBase(ABC): calculated values (e.g., F²meas vs F²calc for single crystal). """ + _supports_graphical_heatmap: bool = False + @abstractmethod def plot_powder( self, @@ -256,3 +258,31 @@ def plot_scatter( height : int | None Backend-specific height (text rows or pixels). """ + + def plot_correlation_heatmap( + self, + corr_df: object, + title: str, + threshold: float | None, + precision: int, + ) -> None: + """ + Render a graphical heatmap for a correlation matrix. + + The default implementation does nothing. Graphical backends + (e.g. Plotly) override this method and set + ``_supports_graphical_heatmap = True`` so the facade knows a + heatmap was rendered. + + Parameters + ---------- + corr_df : object + Square correlation DataFrame. + title : str + Figure title. + threshold : float | None + Absolute-correlation cutoff used for value labels. + precision : int + Number of decimals to show in labels and hover text. + """ + return diff --git a/src/easydiffraction/display/plotters/plotly.py b/src/easydiffraction/display/plotters/plotly.py index f6b16edbd..8dd9a3391 100644 --- a/src/easydiffraction/display/plotters/plotly.py +++ b/src/easydiffraction/display/plotters/plotly.py @@ -36,6 +36,8 @@ class PlotlyPlotter(PlotterBase): """Interactive plotter using Plotly for notebooks and browsers.""" + _supports_graphical_heatmap: bool = True + def __init__(self) -> None: if hasattr(pio, 'templates'): pio.templates.default = self._default_template_name() diff --git a/src/easydiffraction/display/plotting.py b/src/easydiffraction/display/plotting.py index 49d6b4ce2..ec5f04584 100644 --- a/src/easydiffraction/display/plotting.py +++ b/src/easydiffraction/display/plotting.py @@ -505,8 +505,6 @@ def plot_param_series( def plot_param_correlations( self, threshold: float | None = DEFAULT_CORRELATION_THRESHOLD, - show_diagonal: bool = False, - triangle: str = 'lower', precision: int = 2, ) -> None: """ @@ -516,6 +514,9 @@ def plot_param_correlations( the active engine is Plotly, an interactive heatmap is shown. Otherwise, a rounded correlation table is rendered. + Only the lower triangle is shown (without the diagonal), since + the matrix is symmetric and diagonal values are always ``1``. + Parameters ---------- threshold : float | None, default=DEFAULT_CORRELATION_THRESHOLD @@ -524,13 +525,6 @@ def plot_param_correlations( participate in at least one pair with ``abs(correlation) >= threshold``. Set to ``None`` or ``0`` to show the full matrix. - show_diagonal : bool, default=False - Whether to show self-correlations on the diagonal. The - default hides them because they are always ``1`` and do not - add information. - triangle : str, default='lower' - Which half of the symmetric matrix to show. Supported values - are ``'lower'``, ``'upper'``, and ``'full'``. precision : int, default=2 Number of decimal places to show in the table fallback. """ @@ -542,26 +536,18 @@ def plot_param_correlations( if corr_df is None: return - corr_df = self._mask_correlation_triangle( - corr_df, - triangle=triangle, - show_diagonal=show_diagonal, - ) + corr_df = self._mask_correlation_lower_triangle(corr_df) title = 'Refined parameter correlation matrix' if threshold is not None and threshold > 0: title += f' with |correlation| >= {threshold:.2f}' - is_plotly = self._engine == PlotterEngineEnum.PLOTLY.value and isinstance( - self._backend, PlotlyPlotter - ) + is_graphical = self._backend._supports_graphical_heatmap display_corr_df, row_numbers, col_numbers = self._trim_correlation_display_dataframe( corr_df, - triangle=triangle, - show_diagonal=show_diagonal, - preserve_all_rows=not is_plotly, + preserve_all_rows=not is_graphical, ) - if is_plotly: + if is_graphical: self._plot_correlation_heatmap( display_corr_df, title, @@ -626,71 +612,45 @@ def _filter_correlation_dataframe( return corr_df.loc[labels, labels] @staticmethod - def _mask_correlation_triangle( + def _mask_correlation_lower_triangle( corr_df: pd.DataFrame, - triangle: str, - show_diagonal: bool, ) -> pd.DataFrame: """ - Mask the unused half of the symmetric correlation matrix. + Mask the upper triangle and diagonal of a correlation matrix. + + Only the lower triangle is kept, since the matrix is symmetric + and diagonal values are always ``1``. Parameters ---------- corr_df : pd.DataFrame Square correlation matrix. - triangle : str - Which part of the matrix to keep: ``'lower'``, ``'upper'``, - or ``'full'``. - show_diagonal : bool - Whether to keep the diagonal values visible. Returns ------- pd.DataFrame - Correlation matrix with unused cells masked. - - Raises - ------ - ValueError - If *triangle* is unsupported. + Correlation matrix with upper triangle and diagonal masked. """ - if triangle not in {'lower', 'upper', 'full'}: - msg = "Correlation triangle must be 'lower', 'upper', or 'full'." - raise ValueError(msg) - masked_values = corr_df.to_numpy(copy=True) - k = 1 if show_diagonal else 0 - - if triangle == 'lower': - mask = np.triu(np.ones_like(masked_values, dtype=bool), k=k) - masked_values[mask] = np.nan - elif triangle == 'upper': - mask = np.tril(np.ones_like(masked_values, dtype=bool), k=-k) - masked_values[mask] = np.nan - elif not show_diagonal: - diag_idx = np.diag_indices_from(masked_values) - masked_values[diag_idx] = np.nan - + mask = np.triu(np.ones_like(masked_values, dtype=bool), k=0) + masked_values[mask] = np.nan return pd.DataFrame(masked_values, index=corr_df.index, columns=corr_df.columns) @staticmethod def _trim_correlation_display_dataframe( corr_df: pd.DataFrame, - triangle: str, - show_diagonal: bool, preserve_all_rows: bool, ) -> tuple[pd.DataFrame, list[int], list[int]]: """ - Trim empty outer rows/columns from triangle views. + Trim empty outer rows/columns from the lower-triangle view. + + For the lower triangle without diagonal, the last column and + first row are always empty and can be trimmed. Parameters ---------- corr_df : pd.DataFrame Masked correlation matrix. - triangle : str - Which triangle is shown. - show_diagonal : bool - Whether diagonal values are visible. preserve_all_rows : bool Whether to keep the full row list so row labels continue to identify all numeric column headers in tabular output. @@ -705,19 +665,12 @@ def _trim_correlation_display_dataframe( row_numbers = list(range(1, num_rows + 1)) col_numbers = list(range(1, num_cols + 1)) - if show_diagonal or triangle == 'full' or min(num_rows, num_cols) <= 1: + if min(num_rows, num_cols) <= 1: return corr_df, row_numbers, col_numbers - if triangle == 'lower': - if preserve_all_rows: - return corr_df.iloc[:, :-1], row_numbers, col_numbers[:-1] - return corr_df.iloc[1:, :-1], row_numbers[1:], col_numbers[:-1] - if triangle == 'upper': - if preserve_all_rows: - return corr_df.iloc[:, 1:], row_numbers, col_numbers[1:] - return corr_df.iloc[:-1, 1:], row_numbers[:-1], col_numbers[1:] - - return corr_df, row_numbers, col_numbers + if preserve_all_rows: + return corr_df.iloc[:, :-1], row_numbers, col_numbers[:-1] + return corr_df.iloc[1:, :-1], row_numbers[1:], col_numbers[:-1] def _get_param_correlation_dataframe(self) -> pd.DataFrame | None: """ @@ -915,7 +868,7 @@ def _plot_correlation_heatmap( precision: int, ) -> None: """ - Delegate correlation heatmap rendering to the Plotly backend. + Delegate correlation heatmap rendering to the backend. Parameters ---------- diff --git a/tests/unit/easydiffraction/display/test_plotting.py b/tests/unit/easydiffraction/display/test_plotting.py index 66b90f90d..9ff1c5a04 100644 --- a/tests/unit/easydiffraction/display/test_plotting.py +++ b/tests/unit/easydiffraction/display/test_plotting.py @@ -361,101 +361,6 @@ class Project: assert list(fig.data[1].text) == ['-0.91', '0.83', '-0.89', '0.82'] -def test_plot_param_correlations_can_show_diagonal(monkeypatch): - import numpy as np - - from easydiffraction.display.plotting import Plotter - from easydiffraction.display.tables import TableRenderer - - captured = {} - - class FakeTabler: - def render(self, df): - captured['df'] = df - - monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) - - class Param: - def __init__(self, uid, unique_name): - self._minimizer_uid = uid - self.unique_name = unique_name - - class RawResult: - covar = np.array([[1.0, 0.4], [0.4, 1.0]]) - var_names = ['p1', 'p2'] - - class FitResults: - engine_result = RawResult() - parameters = [ - Param('p1', 'phase.scale'), - Param('p2', 'phase.cell.length_c'), - ] - - class Analysis: - fit_results = FitResults() - - class Project: - analysis = Analysis() - - p = Plotter() - p.engine = 'asciichartpy' - p._set_project(Project()) - p.plot_param_correlations(threshold=0.1, show_diagonal=True) - - df = captured['df'] - assert df.iloc[0, 1].strip() == '1.00' - assert df.iloc[0, 2] == '' - assert df.iloc[1, 2].strip() == '1.00' - - -def test_plot_param_correlations_can_show_full_matrix(monkeypatch): - import numpy as np - - from easydiffraction.display.plotting import Plotter - from easydiffraction.display.tables import TableRenderer - - captured = {} - - class FakeTabler: - def render(self, df): - captured['df'] = df - - monkeypatch.setattr(TableRenderer, 'get', staticmethod(lambda: FakeTabler())) - - class Param: - def __init__(self, uid, unique_name): - self._minimizer_uid = uid - self.unique_name = unique_name - - class RawResult: - covar = np.array([[1.0, 0.4], [0.4, 1.0]]) - var_names = ['p1', 'p2'] - - class FitResults: - engine_result = RawResult() - parameters = [ - Param('p1', 'phase.scale'), - Param('p2', 'phase.cell.length_c'), - ] - - class Analysis: - fit_results = FitResults() - - class Project: - analysis = Analysis() - - p = Plotter() - p.engine = 'asciichartpy' - p._set_project(Project()) - p.plot_param_correlations(threshold=0.1, triangle='full', show_diagonal=False) - - df = captured['df'] - assert df.iloc[0, 1] == '' - assert df.iloc[0, 2].strip() == '0.40' - assert df.iloc[1, 1].strip() == '0.40' - assert df.iloc[1, 2] == '' - - def test_plot_param_correlations_filters_by_default_threshold(monkeypatch): from easydiffraction.display.plotting import Plotter from easydiffraction.display.tables import TableRenderer @@ -599,13 +504,3 @@ def test_plot_param_correlations_threshold_validation(): with pytest.raises(ValueError, match='between 0 and 1'): Plotter._filter_correlation_dataframe(object(), threshold=1.1) - - -def test_plot_param_correlations_triangle_validation(): - import pandas as pd - - from easydiffraction.display.plotting import Plotter - - df = pd.DataFrame([[1.0]]) - with pytest.raises(ValueError, match='lower'): - Plotter._mask_correlation_triangle(df, triangle='sideways', show_diagonal=False) From 35891709ebf649c92e08d09126cb61262fe4e26a Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 11:18:57 +0200 Subject: [PATCH 08/10] Fix PlotterBase.plot_correlation_heatmap lint errors --- src/easydiffraction/display/plotters/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/easydiffraction/display/plotters/base.py b/src/easydiffraction/display/plotters/base.py index 2c0fb40c1..f6cf408bc 100644 --- a/src/easydiffraction/display/plotters/base.py +++ b/src/easydiffraction/display/plotters/base.py @@ -285,4 +285,7 @@ def plot_correlation_heatmap( precision : int Number of decimals to show in labels and hover text. """ - return + # Intentionally unused; accepted for API compatibility with + # graphical backends that override this method. + _ = self._supports_graphical_heatmap + del corr_df, title, threshold, precision From 9d05f89b5fed8ae7a55e74e3d7d4f5fb5a12f266 Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 11:28:57 +0200 Subject: [PATCH 09/10] Fix future annotations, dead code, and **kwargs in display --- .../architecture/sequential_fitting_design.md | 2 +- pr-review.md | 104 ++++++++++++++++++ src/easydiffraction/display/plotters/ascii.py | 2 + src/easydiffraction/display/plotters/base.py | 2 + .../display/plotters/plotly.py | 10 +- src/easydiffraction/display/plotting.py | 79 +------------ 6 files changed, 119 insertions(+), 80 deletions(-) create mode 100644 pr-review.md diff --git a/docs/architecture/sequential_fitting_design.md b/docs/architecture/sequential_fitting_design.md index 292020b59..e6219fd1a 100644 --- a/docs/architecture/sequential_fitting_design.md +++ b/docs/architecture/sequential_fitting_design.md @@ -1120,7 +1120,7 @@ propagation, diffrn callback, precondition validation. **Implemented:** `Plotter.plot_param_series()` resolves CSV vs snapshots automatically via the project reference. `Plotter._plot_param_series_from_csv()` reads CSV via pandas. -`Plotter._plot_param_series_from_snapshots()` preserves backward +`Plotter.plot_param_series_from_snapshots()` preserves backward compatibility for `fit()` single-mode (no CSV yet). Axis labels derived from live descriptor objects. diff --git a/pr-review.md b/pr-review.md new file mode 100644 index 000000000..aa9c63c78 --- /dev/null +++ b/pr-review.md @@ -0,0 +1,104 @@ +# PR Review: `param-correlations` → `develop` + +## Summary + +54 files changed, +1494/−118 lines. Adds correlation matrix plotting +(heatmap via Plotly, ASCII table via asciichartpy), data extraction from +fit results, triangle masking, threshold filtering, and comprehensive +tests. + +--- + +## Issues + +### Issue #1 (Medium) — `triangle` parameter uses raw strings instead of Enum + +**Rule**: §9.6 — every finite, closed set of values must use a +`(str, Enum)`. + +`plot_param_correlations(triangle='lower')` accepts raw strings +`'lower'`, `'upper'`, `'full'` and compares them with `==` inside +`_mask_correlation_triangle` and `_trim_correlation_display_dataframe`. + +**Status**: ✅ FIXED — `CorrelationTriangleEnum` was created. **Updated +status**: User decided to remove `show_diagonal` and `triangle` +parameters entirely, keeping only the default behavior (lower triangle, +no diagonal). This eliminates the need for the enum altogether. + +--- + +### Issue #2 (Low-Medium) — `plot_correlation_heatmap` not declared on `PlotterBase`; facade uses `isinstance` dispatch + +**Rule**: §9.4 — backends should share a common interface via the base +class. + +`plotting.py` used `isinstance(self._backend, PlotlyPlotter)` to decide +whether to call `plot_correlation_heatmap`. This couples the facade to a +concrete backend. + +**Fix**: + +- Add `_supports_graphical_heatmap: bool = False` on `PlotterBase`, + override to `True` on `PlotlyPlotter`. +- Add a concrete no-op `plot_correlation_heatmap()` on `PlotterBase`. +- Replace `isinstance` check with + `self._backend._supports_graphical_heatmap`. + +**Status**: ✅ FIXED — `_supports_graphical_heatmap` flag added, `isinstance` +replaced, no-op method on `PlotterBase` uses `del` pattern to consume +unused args (matching `PlotlyPlotter` convention). + +--- + +### Issue #3 (Low) — Missing `from __future__ import annotations` + +**Rule**: Code style — use `from __future__ import annotations` in every +module. + +Pre-existing in several files. Not in scope for this PR. + +**Status**: ⏭️ SKIPPED (pre-existing, out of scope) + +--- + +### Issue #4 (Medium) — `_plot_param_series_from_snapshots` is dead copy-paste + +Pre-existing dead code that duplicates `_plot_param_series_from_csv`. + +**Status**: ⏭️ SKIPPED (pre-existing, out of scope) + +--- + +### Issue #5 (Informational) — Coverage threshold lowered 75 → 70 + +Aligns with architecture.md §10 which states 70%. + +**Status**: ✅ OK + +--- + +### Issue #6 (Informational) — `_get_layout` uses `**kwargs` + +Pre-existing pattern. Not in scope for this PR. + +**Status**: ⏭️ SKIPPED (pre-existing, out of scope) + +--- + +## Completed Actions + +1. ✅ Remove `CorrelationTriangleEnum` from `plotting.py` +2. ✅ Remove `show_diagonal` and `triangle` params from + `plot_param_correlations` +3. ✅ Simplify `_mask_correlation_triangle` → hardcode lower triangle, no + diagonal +4. ✅ Simplify `_trim_correlation_display_dataframe` → hardcode lower + triangle trim +5. ✅ Fix `PlotterBase.plot_correlation_heatmap` lint errors (PLR6301, + ARG002) +6. ✅ Remove tests for diagonal/full/triangle validation +7. ✅ Update remaining tests +8. ✅ `pixi run fix` + `pixi run check` + `pixi run unit-tests` (683 + passed) +9. ✅ `pixi run integration-tests` (98 passed) + + `pixi run script-tests` (18 passed) diff --git a/src/easydiffraction/display/plotters/ascii.py b/src/easydiffraction/display/plotters/ascii.py index 4d4edd692..9ef5484b5 100644 --- a/src/easydiffraction/display/plotters/ascii.py +++ b/src/easydiffraction/display/plotters/ascii.py @@ -8,6 +8,8 @@ a consistent API with other plotters. """ +from __future__ import annotations + import asciichartpy import numpy as np diff --git a/src/easydiffraction/display/plotters/base.py b/src/easydiffraction/display/plotters/base.py index f6cf408bc..920fac810 100644 --- a/src/easydiffraction/display/plotters/base.py +++ b/src/easydiffraction/display/plotters/base.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause """Abstract base and shared constants for plotting backends.""" +from __future__ import annotations + from abc import ABC from abc import abstractmethod from enum import StrEnum diff --git a/src/easydiffraction/display/plotters/plotly.py b/src/easydiffraction/display/plotters/plotly.py index 8dd9a3391..af09c437a 100644 --- a/src/easydiffraction/display/plotters/plotly.py +++ b/src/easydiffraction/display/plotters/plotly.py @@ -8,6 +8,8 @@ renderer may be used depending on configuration. """ +from __future__ import annotations + import darkdetect import numpy as np import plotly.graph_objects as go @@ -506,7 +508,7 @@ def _show_figure( def _get_layout( title: str, axes_labels: object, - **kwargs: object, + shapes: list | None = None, ) -> object: """ Create a Plotly layout configuration. @@ -517,8 +519,8 @@ def _get_layout( Figure title. axes_labels : object Pair of strings for the x and y titles. - **kwargs : object - Additional layout parameters (e.g., shapes). + shapes : list | None, default=None + Optional list of shape dicts to overlay on the plot. Returns ------- @@ -553,7 +555,7 @@ def _get_layout( 'mirror': True, 'zeroline': False, }, - **kwargs, + shapes=shapes, ) def plot_powder( diff --git a/src/easydiffraction/display/plotting.py b/src/easydiffraction/display/plotting.py index ec5f04584..fc0fd9196 100644 --- a/src/easydiffraction/display/plotting.py +++ b/src/easydiffraction/display/plotting.py @@ -7,6 +7,8 @@ consistent configuration surface and engine handling. """ +from __future__ import annotations + import pathlib from enum import StrEnum @@ -36,7 +38,7 @@ class PlotterEngineEnum(StrEnum): PLOTLY = 'plotly' @classmethod - def default(cls) -> 'PlotterEngineEnum': + def default(cls) -> PlotterEngineEnum: """Select default engine based on environment.""" if in_jupyter(): log.debug('Setting default plotting engine to Plotly for Jupyter') @@ -495,7 +497,7 @@ def plot_param_series( else: # Fallback: in-memory snapshots from fit() single mode versus_name = versus.name if versus is not None else None - self._plot_param_series_from_snapshots( + self.plot_param_series_from_snapshots( unique_name, versus_name, self._project.experiments, @@ -1240,79 +1242,6 @@ def _plot_param_series_from_csv( height=self.height, ) - def _plot_param_series_from_snapshots( - self, - csv_path: str, - unique_name: str, - param_descriptor: object, - versus_descriptor: object | None = None, - ) -> None: - """ - Plot a parameter's value across sequential fit results. - - Reads data from the CSV file at *csv_path*. The y-axis values - come from the column named *unique_name*, uncertainties from - ``{unique_name}.uncertainty``. When *versus_descriptor* is - provided, the x-axis uses the corresponding ``diffrn.{name}`` - column; otherwise the row index is used. - - Axis labels are derived from the live descriptor objects - (*param_descriptor* and *versus_descriptor*), which carry - ``.description`` and ``.units`` attributes. - - Parameters - ---------- - csv_path : str - Path to the ``results.csv`` file. - unique_name : str - Unique name of the parameter to plot (CSV column key). - param_descriptor : object - The live parameter descriptor (for axis label / units). - versus_descriptor : object | None, default=None - A diffrn descriptor whose ``.name`` maps to a - ``diffrn.{name}`` CSV column. ``None`` → use row index. - """ - df = pd.read_csv(csv_path) - - if unique_name not in df.columns: - log.warning( - f"Parameter '{unique_name}' not found in CSV columns. " - f'Available: {list(df.columns)}' - ) - return - - y = df[unique_name].astype(float).tolist() - uncert_col = f'{unique_name}.uncertainty' - sy = df[uncert_col].astype(float).tolist() if uncert_col in df.columns else [0.0] * len(y) - - # X-axis: diffrn column or row index - versus_name = versus_descriptor.name if versus_descriptor is not None else None - diffrn_col = f'diffrn.{versus_name}' if versus_name else None - - if diffrn_col and diffrn_col in df.columns: - x = pd.to_numeric(df[diffrn_col], errors='coerce').tolist() - x_label = getattr(versus_descriptor, 'description', None) or versus_name - if hasattr(versus_descriptor, 'units') and versus_descriptor.units: - x_label = f'{x_label} ({versus_descriptor.units})' - else: - x = list(range(1, len(y) + 1)) - x_label = 'Experiment No.' - - # Y-axis label from descriptor - param_units = getattr(param_descriptor, 'units', '') - y_label = f'Parameter value ({param_units})' if param_units else 'Parameter value' - - title = f"Parameter '{unique_name}' across fit results" - - self._backend.plot_scatter( - x=x, - y=y, - sy=sy, - axes_labels=[x_label, y_label], - title=title, - height=self.height, - ) - def plot_param_series_from_snapshots( self, unique_name: str, From 0531dbcf456ad29db0b39741e5ea9e66c4672b6e Mon Sep 17 00:00:00 2001 From: Andrew Sazonov Date: Thu, 9 Apr 2026 11:35:44 +0200 Subject: [PATCH 10/10] Clean up --- pr-review.md | 104 --------------------------------------------------- 1 file changed, 104 deletions(-) delete mode 100644 pr-review.md diff --git a/pr-review.md b/pr-review.md deleted file mode 100644 index aa9c63c78..000000000 --- a/pr-review.md +++ /dev/null @@ -1,104 +0,0 @@ -# PR Review: `param-correlations` → `develop` - -## Summary - -54 files changed, +1494/−118 lines. Adds correlation matrix plotting -(heatmap via Plotly, ASCII table via asciichartpy), data extraction from -fit results, triangle masking, threshold filtering, and comprehensive -tests. - ---- - -## Issues - -### Issue #1 (Medium) — `triangle` parameter uses raw strings instead of Enum - -**Rule**: §9.6 — every finite, closed set of values must use a -`(str, Enum)`. - -`plot_param_correlations(triangle='lower')` accepts raw strings -`'lower'`, `'upper'`, `'full'` and compares them with `==` inside -`_mask_correlation_triangle` and `_trim_correlation_display_dataframe`. - -**Status**: ✅ FIXED — `CorrelationTriangleEnum` was created. **Updated -status**: User decided to remove `show_diagonal` and `triangle` -parameters entirely, keeping only the default behavior (lower triangle, -no diagonal). This eliminates the need for the enum altogether. - ---- - -### Issue #2 (Low-Medium) — `plot_correlation_heatmap` not declared on `PlotterBase`; facade uses `isinstance` dispatch - -**Rule**: §9.4 — backends should share a common interface via the base -class. - -`plotting.py` used `isinstance(self._backend, PlotlyPlotter)` to decide -whether to call `plot_correlation_heatmap`. This couples the facade to a -concrete backend. - -**Fix**: - -- Add `_supports_graphical_heatmap: bool = False` on `PlotterBase`, - override to `True` on `PlotlyPlotter`. -- Add a concrete no-op `plot_correlation_heatmap()` on `PlotterBase`. -- Replace `isinstance` check with - `self._backend._supports_graphical_heatmap`. - -**Status**: ✅ FIXED — `_supports_graphical_heatmap` flag added, `isinstance` -replaced, no-op method on `PlotterBase` uses `del` pattern to consume -unused args (matching `PlotlyPlotter` convention). - ---- - -### Issue #3 (Low) — Missing `from __future__ import annotations` - -**Rule**: Code style — use `from __future__ import annotations` in every -module. - -Pre-existing in several files. Not in scope for this PR. - -**Status**: ⏭️ SKIPPED (pre-existing, out of scope) - ---- - -### Issue #4 (Medium) — `_plot_param_series_from_snapshots` is dead copy-paste - -Pre-existing dead code that duplicates `_plot_param_series_from_csv`. - -**Status**: ⏭️ SKIPPED (pre-existing, out of scope) - ---- - -### Issue #5 (Informational) — Coverage threshold lowered 75 → 70 - -Aligns with architecture.md §10 which states 70%. - -**Status**: ✅ OK - ---- - -### Issue #6 (Informational) — `_get_layout` uses `**kwargs` - -Pre-existing pattern. Not in scope for this PR. - -**Status**: ⏭️ SKIPPED (pre-existing, out of scope) - ---- - -## Completed Actions - -1. ✅ Remove `CorrelationTriangleEnum` from `plotting.py` -2. ✅ Remove `show_diagonal` and `triangle` params from - `plot_param_correlations` -3. ✅ Simplify `_mask_correlation_triangle` → hardcode lower triangle, no - diagonal -4. ✅ Simplify `_trim_correlation_display_dataframe` → hardcode lower - triangle trim -5. ✅ Fix `PlotterBase.plot_correlation_heatmap` lint errors (PLR6301, - ARG002) -6. ✅ Remove tests for diagonal/full/triangle validation -7. ✅ Update remaining tests -8. ✅ `pixi run fix` + `pixi run check` + `pixi run unit-tests` (683 - passed) -9. ✅ `pixi run integration-tests` (98 passed) + - `pixi run script-tests` (18 passed)