|
2 | 2 | import numpy as np |
3 | 3 | from scipy.stats import norm, t |
4 | 4 | from numpy.linalg import pinv, inv, LinAlgError |
| 5 | +import copy |
| 6 | +import itertools |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import seaborn as sns |
| 9 | +from statsmodels.formula.api import ols |
5 | 10 |
|
6 | 11 |
|
7 | 12 | def cov_to_cor(covariance_matrix) -> np.ndarray: |
@@ -199,3 +204,270 @@ def pairwise_semi_partial_correlation(x, y, z, method="pearson"): |
199 | 204 | "gp": spcor_result["gp"], |
200 | 205 | "method": method, |
201 | 206 | } |
| 207 | + |
| 208 | + |
| 209 | +def get_all_vars_from_formula(formula) -> list: |
| 210 | + """Utility function to extract variables from a formula. |
| 211 | +
|
| 212 | + Args: |
| 213 | + formula (str): A formula. |
| 214 | +
|
| 215 | + Returns: |
| 216 | + list: A list of variables. |
| 217 | +
|
| 218 | + Examples: |
| 219 | + >>> from spotpython.utils.stats import get_all_vars_from_formula |
| 220 | + get_all_vars_from_formula("y ~ x1 + x2") |
| 221 | + ['y', 'x1', 'x2'] |
| 222 | + get_all_vars_from_formula("y ~ ") |
| 223 | + ['y'] |
| 224 | + """ |
| 225 | + # Split the formula into the dependent and independent variables |
| 226 | + dependent, independent = formula.split("~") |
| 227 | + # Strip whitespace and split the independent variables by '+' |
| 228 | + independent_vars = independent.strip().split("+") if independent.strip() else [] |
| 229 | + # Combine the dependent variable with the independent variables |
| 230 | + return [dependent.strip()] + [var.strip() for var in independent_vars] |
| 231 | + |
| 232 | + |
| 233 | +def fit_all_lm(basic, xlist, data, remove_na=True) -> dict: |
| 234 | + """Fit a linear regression model for all possible combinations of independent variables. |
| 235 | +
|
| 236 | + Args: |
| 237 | + basic (str): The basic model formula. |
| 238 | + xlist (list): A list of independent variables. |
| 239 | + data (pandas.DataFrame): The data frame containing the variables. |
| 240 | + remove_na (bool): Whether to remove missing values from the data frame. |
| 241 | +
|
| 242 | + Returns: |
| 243 | + dict: A dictionary containing the estimated coefficients, confidence intervals, |
| 244 | + p-values, AIC values, sample size, and the basic model formula. |
| 245 | +
|
| 246 | + Examples: |
| 247 | + >>> from spotpython.utils.stats import fit_all_lm |
| 248 | + >>> import pandas as pd |
| 249 | + >>> data = pd.DataFrame({ |
| 250 | + >>> 'y': [1, 2, 3], |
| 251 | + >>> 'x1': [4, 5, 6], |
| 252 | + >>> 'x2': [7, 8, 9] |
| 253 | + >>> }) |
| 254 | + >>> fit_all_lm("y ~ x1", ["x2"], data) |
| 255 | + {'estimate': variables estimate conf_low conf_high p aic n |
| 256 | + 0 basic 1.000000 1.000000 1.000000 0.0 0.000000 3 |
| 257 | + 1 x2 1.000000 1.000000 1.000000 0.0 0.000000 3} |
| 258 | + """ |
| 259 | + # Prepare the data frame |
| 260 | + data = copy.deepcopy(data) |
| 261 | + data = data[get_all_vars_from_formula(basic) + xlist] |
| 262 | + if remove_na: |
| 263 | + data = data.dropna() |
| 264 | + print(data.head()) |
| 265 | + # basic model |
| 266 | + mod_0 = ols(basic, data=data).fit() |
| 267 | + p = mod_0.pvalues.iloc[1] |
| 268 | + print(f"p-values: {p}") |
| 269 | + estimate = mod_0.params.iloc[1] |
| 270 | + print(f"estimate: {estimate}") |
| 271 | + conf_int = mod_0.conf_int().iloc[1] |
| 272 | + print(f"conf_int: {conf_int}") |
| 273 | + aic_value = mod_0.aic |
| 274 | + print(f"aic: {aic_value}") |
| 275 | + n = len(mod_0.resid) |
| 276 | + df_0 = pd.DataFrame([["basic", estimate, conf_int[0], conf_int[1], p, aic_value, n]], columns=["variables", "estimate", "conf_low", "conf_high", "p", "aic", "n"]) |
| 277 | + |
| 278 | + # All combinations model |
| 279 | + comb_lst = list(itertools.chain.from_iterable(itertools.combinations(xlist, r) for r in range(1, len(xlist) + 1))) |
| 280 | + models = [ols(f"{basic} + {' + '.join(comb)}", data=data).fit() for comb in comb_lst] |
| 281 | + |
| 282 | + df_list = [] |
| 283 | + for i, model in enumerate(models): |
| 284 | + p = model.pvalues.iloc[1] |
| 285 | + estimate = model.params.iloc[1] |
| 286 | + conf_int = model.conf_int().iloc[1] |
| 287 | + aic_value = model.aic |
| 288 | + n = len(model.resid) |
| 289 | + comb_str = ", ".join(comb_lst[i]) |
| 290 | + df_list.append([comb_str, estimate, conf_int[0], conf_int[1], p, aic_value, n]) |
| 291 | + |
| 292 | + df_coef = pd.DataFrame(df_list, columns=["variables", "estimate", "conf_low", "conf_high", "p", "aic", "n"]) |
| 293 | + estimates = pd.concat([df_0, df_coef], ignore_index=True) |
| 294 | + return {"estimate": estimates, "xlist": xlist, "fun": "all_lm", "basic": basic, "family": "lm"} |
| 295 | + |
| 296 | + |
| 297 | +def plot_coeff_vs_pvals(data, xlabels=None, xlim=(0, 1), xlab="p-value", ylim=None, ylab=None, xscale_log=True, yscale_log=False, title=None, show=True) -> None: |
| 298 | + """Plot the coefficient estimates from fit_all_lm against the corresponding p-values. |
| 299 | +
|
| 300 | + Args: |
| 301 | + data (dict): |
| 302 | + A dictionary containing the estimated coefficients, p-values, and other information. |
| 303 | + Generated by the fit_all_lm function. |
| 304 | + xlabels (list): |
| 305 | + A list of x-axis labels. |
| 306 | + xlim (tuple): |
| 307 | + A tuple of the x-axis limits. |
| 308 | + xlab (str): |
| 309 | + The x-axis label. |
| 310 | + ylim (tuple): |
| 311 | + A tuple of the y-axis limits. |
| 312 | + ylab (str): |
| 313 | + The y-axis label. |
| 314 | + xscale_log (bool): |
| 315 | + Whether to use a log scale on the x-axis. |
| 316 | + yscale_log (bool): |
| 317 | + Whether to use a log scale on the y-axis. |
| 318 | + title (str): |
| 319 | + The plot title. |
| 320 | + show (bool): |
| 321 | + Whether to display the plot. |
| 322 | +
|
| 323 | + Returns: |
| 324 | + None |
| 325 | +
|
| 326 | + Notes: |
| 327 | + * Based on the R package 'allestimates' by Zhiqiang Wang, see https://cran.r-project.org/package=allestimates |
| 328 | +
|
| 329 | + References: |
| 330 | + Wang, Z. (2007). Two Postestimation Commands for Assessing Confounding Effects in Epidemiological Studies. The Stata Journal, 7(2), 183-196. https://doi.org/10.1177/1536867X0700700203 |
| 331 | +
|
| 332 | + Examples: |
| 333 | + >>> from spotpython.utils.stats import plot_coeff_vs_pvals, fit_all_lm |
| 334 | + >>> import pandas as pd |
| 335 | + >>> data = pd.DataFrame({ |
| 336 | + >>> 'y': [1, 2, 3], |
| 337 | + >>> 'x1': [4, 5, 6], |
| 338 | + >>> 'x2': [7, 8, 9] |
| 339 | + >>> }) |
| 340 | + >>> estimates = fit_all_lm("y ~ x1", ["x2"], data) |
| 341 | + >>> plot_coeff_vs_pvals(estimates) |
| 342 | + """ |
| 343 | + data = copy.deepcopy(data) |
| 344 | + if xlabels is None: |
| 345 | + xlabels = [0, 0.001, 0.01, 0.05, 0.2, 0.5, 1] |
| 346 | + xbreaks = np.power(xlabels, np.log(0.5) / np.log(0.05)) |
| 347 | + |
| 348 | + result_df = data["estimate"] |
| 349 | + if ylab is None: |
| 350 | + ylab = "Coefficient" if data["fun"] == "all_lm" else "Effect estimates" |
| 351 | + hline = 0 if data["fun"] == "all_lm" else 1 |
| 352 | + |
| 353 | + result_df["p_value"] = np.power(result_df["p"], np.log(0.5) / np.log(0.05)) |
| 354 | + if ylim is None: |
| 355 | + maxv = max(result_df["estimate"].max(), abs(result_df["estimate"].min())) |
| 356 | + ylim = (-maxv, maxv) if data["fun"] == "all_lm" else (1 / maxv, maxv) |
| 357 | + |
| 358 | + plt.figure(figsize=(10, 6)) |
| 359 | + sns.scatterplot(data=result_df, x="p_value", y="estimate") |
| 360 | + if xscale_log: |
| 361 | + plt.xscale("log") |
| 362 | + if yscale_log: |
| 363 | + plt.yscale("log") |
| 364 | + plt.xticks(ticks=xbreaks, labels=xlabels) |
| 365 | + plt.axvline(x=0.5, linestyle="--") |
| 366 | + plt.axhline(y=hline, linestyle="--") |
| 367 | + plt.xlim(xlim) |
| 368 | + plt.ylim(ylim) |
| 369 | + plt.xlabel(xlab) |
| 370 | + plt.ylabel(ylab) |
| 371 | + if title: |
| 372 | + plt.title(title) |
| 373 | + plt.grid(True) |
| 374 | + if show: |
| 375 | + plt.show() |
| 376 | + |
| 377 | + |
| 378 | +def plot_coeff_vs_pvals_by_included(data, xlabels=None, xlim=(0, 1), xlab="P value", ylim=None, ylab=None, yscale_log=False, title=None, grid=True, ncol=2, show=True) -> None: |
| 379 | + """ |
| 380 | + Generates a panel of scatter plots with effect estimates of all possible models against p-values. |
| 381 | + Uses a dictionry generated by the fit_all_lm function. |
| 382 | + Each plot includes effect estimates from all models including a specific variable. |
| 383 | +
|
| 384 | + Args: |
| 385 | + data (dict): A dictionary, generated by the fit_all_lm function, containing the following keys: |
| 386 | + - estimate (pd.DataFrame): A DataFrame containing the estimates. |
| 387 | + - xlist (list): A list of variables. |
| 388 | + - fun (str): The function name. |
| 389 | + - family (str): The family of the model. |
| 390 | + xlabels (list): A list of x-axis labels. |
| 391 | + xlim (tuple): The x-axis limits. |
| 392 | + xlab (str): The x-axis label. |
| 393 | + ylim (tuple): The y-axis limits. |
| 394 | + ylab (str): The y-axis label. |
| 395 | + yscale_log (bool): Whether to scale y-axis to log10. Default is False. |
| 396 | + title (str): The title of the plot. |
| 397 | + grid (bool): Whether to display gridlines. Default is True. |
| 398 | + ncol (int): Number of columns in the plot grid. Default is 2. |
| 399 | +
|
| 400 | + Returns: |
| 401 | + None |
| 402 | +
|
| 403 | + Notes: |
| 404 | + * Based on the R package 'allestimates' by Zhiqiang Wang, see https://cran.r-project.org/package=allestimates |
| 405 | +
|
| 406 | + References: |
| 407 | + Wang, Z. (2007). Two Postestimation Commands for Assessing Confounding Effects in Epidemiological Studies. The Stata Journal, 7(2), 183-196. https://doi.org/10.1177/1536867X0700700203 |
| 408 | +
|
| 409 | +
|
| 410 | + Examples: |
| 411 | + data = { |
| 412 | + "estimate": pd.DataFrame({ |
| 413 | + "variables": ["Crude", "AL", "AM", "AN", "AO"], |
| 414 | + "estimate": [0.5, 0.6, 0.7, 0.8, 0.9], |
| 415 | + "conf_low": [0.1, 0.2, 0.3, 0.4, 0.5], |
| 416 | + "conf_high": [0.9, 1.0, 1.1, 1.2, 1.3], |
| 417 | + "p": [0.01, 0.02, 0.03, 0.04, 0.05], |
| 418 | + "aic": [100, 200, 300, 400, 500], |
| 419 | + "n": [10, 20, 30, 40, 50] |
| 420 | + }), |
| 421 | + "xlist": ["AL", "AM", "AN", "AO"], |
| 422 | + "fun": "all_lm" |
| 423 | + } |
| 424 | + plot_coeff_vs_pvals_by_included(data) |
| 425 | + """ |
| 426 | + if xlabels is None: |
| 427 | + xlabels = [0, 0.001, 0.01, 0.05, 0.2, 0.5, 1] |
| 428 | + xbreaks = np.power(xlabels, np.log(0.5) / np.log(0.05)) |
| 429 | + |
| 430 | + result_df = data["estimate"] |
| 431 | + if ylab is None: |
| 432 | + ylab = {"all_lm": "Coefficient", "poisson": "Rate ratio", "binomial": "Odds ratio"}.get(data.get("fun"), "Effect estimates") |
| 433 | + |
| 434 | + hline = 0 if data["fun"] == "all_lm" else 1 |
| 435 | + |
| 436 | + result_df["p_value"] = np.power(result_df["p"], np.log(0.5) / np.log(0.05)) |
| 437 | + if ylim is None: |
| 438 | + maxv = max(result_df["estimate"].max(), abs(result_df["estimate"].min())) |
| 439 | + if data["fun"] == "all_lm": |
| 440 | + ylim = (-maxv, maxv) |
| 441 | + else: |
| 442 | + ylim = (1 / maxv, maxv) |
| 443 | + |
| 444 | + # Create a DataFrame to mark inclusion of variables |
| 445 | + mark_df = pd.DataFrame({x: result_df["variables"].str.contains(x).astype(int) for x in data["xlist"]}) |
| 446 | + df_scatter = pd.concat([result_df, mark_df], axis=1) |
| 447 | + |
| 448 | + # Melt the DataFrame for plotting |
| 449 | + df_long = df_scatter.melt(id_vars=["variables", "estimate", "conf_low", "conf_high", "p", "aic", "n", "p_value"], value_vars=data["xlist"], var_name="variable", value_name="inclusion") |
| 450 | + df_long["inclusion"] = df_long["inclusion"].apply(lambda x: "Included" if x > 0 else "Not included") |
| 451 | + |
| 452 | + # Plotting |
| 453 | + g = sns.FacetGrid(df_long, col="variable", hue="inclusion", palette={"Included": "blue", "Not included": "orange"}, col_wrap=ncol, height=4, sharex=False, sharey=False) |
| 454 | + g.map(sns.scatterplot, "p_value", "estimate") |
| 455 | + g.add_legend() |
| 456 | + for ax in g.axes.flat: |
| 457 | + ax.set_xticks(xbreaks) |
| 458 | + ax.set_xticklabels(xlabels) |
| 459 | + ax.set_xlim(xlim) |
| 460 | + ax.set_ylim(ylim) |
| 461 | + ax.axvline(x=0.5, linestyle="--", linewidth=1.5, color="black") # Black dashed vertical line |
| 462 | + ax.axhline(y=hline, linestyle="--", linewidth=1.5, color="black") # Black dashed horizontal line |
| 463 | + if grid: |
| 464 | + ax.grid(True) |
| 465 | + if yscale_log: |
| 466 | + g.set(yscale="log") |
| 467 | + g.set_axis_labels(xlab, ylab) |
| 468 | + g.set_titles("{col_name}") |
| 469 | + if title: |
| 470 | + plt.subplots_adjust(top=0.9) |
| 471 | + g.figure.suptitle(title) |
| 472 | + if show: |
| 473 | + plt.show() |
0 commit comments