Skip to content

Commit 714833b

Browse files
0.24.32
New: plot_coeff_vs_pvals_by_included
1 parent 8899dc5 commit 714833b

7 files changed

Lines changed: 451 additions & 183 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 22 additions & 182 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.24.31"
10+
version = "0.24.32"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
@@ -48,6 +48,7 @@ dependencies = [
4848
"scipy",
4949
"spotriver>=0.4.1",
5050
"seaborn",
51+
"statsmodels",
5152
"tabulate",
5253
"tensorboard",
5354
"torch",

src/spotpython/utils/stats.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
import numpy as np
33
from scipy.stats import norm, t
44
from numpy.linalg import pinv, inv, LinAlgError
5+
import copy
6+
import itertools
7+
import matplotlib.pyplot as plt
8+
import seaborn as sns
9+
from statsmodels.formula.api import ols
510

611

712
def cov_to_cor(covariance_matrix) -> np.ndarray:
@@ -199,3 +204,270 @@ def pairwise_semi_partial_correlation(x, y, z, method="pearson"):
199204
"gp": spcor_result["gp"],
200205
"method": method,
201206
}
207+
208+
209+
def get_all_vars_from_formula(formula) -> list:
    """Extract the variable names used in a model formula.

    The formula is split at ``~`` into a dependent part (left) and an
    independent part (right). The right-hand side is split at the term
    operators ``+``, ``*`` and ``:``, so interaction terms such as
    ``x1 * x2`` or ``x1:x2`` contribute their individual variables.
    Empty terms (e.g. from a trailing ``+``) and the intercept markers
    ``0`` and ``1`` are skipped, and repeated names are reported only once.

    The formula is not fully parsed: function calls and parenthesized
    expressions (e.g. ``I(x1*x2)``) are not resolved to column names.

    Args:
        formula (str): A formula of the form ``"y ~ x1 + x2"``.

    Returns:
        list: The dependent variable followed by the independent variables
            in order of first appearance.

    Raises:
        ValueError: If the formula does not contain exactly one ``~``.

    Examples:
        >>> from spotpython.utils.stats import get_all_vars_from_formula
        >>> get_all_vars_from_formula("y ~ x1 + x2")
        ['y', 'x1', 'x2']
        >>> get_all_vars_from_formula("y ~ x1 * x2")
        ['y', 'x1', 'x2']
        >>> get_all_vars_from_formula("y ~ ")
        ['y']
    """
    import re

    parts = formula.split("~")
    if len(parts) != 2:
        raise ValueError(f"Expected exactly one '~' in formula, got: {formula!r}")
    dependent, independent = (part.strip() for part in parts)

    # Split the right-hand side at the term operators and normalize whitespace.
    terms = (term.strip() for term in re.split(r"[+*:]", independent))
    # "0" and "1" only switch the intercept off/on; they are not data columns.
    names = [dependent] + [term for term in terms if term and term not in ("0", "1")]
    # dict.fromkeys removes duplicates while keeping first-appearance order.
    return list(dict.fromkeys(names))
231+
232+
233+
def fit_all_lm(basic, xlist, data, remove_na=True, verbose=False) -> dict:
    """Fit the basic linear model plus one model per combination of extra variables.

    The basic formula is fitted first. Then, for every non-empty combination
    of the names in ``xlist``, the combination is appended to the basic
    formula with ``+`` and the resulting model is fitted. For each model the
    parameter at position 1 is reported, i.e. the first term after the
    intercept when the basic formula has an intercept and at least one
    predictor.

    Args:
        basic (str): The basic model formula, e.g. ``"y ~ x1"``.
        xlist (list): Names of the additional independent variables to combine.
        data (pandas.DataFrame): The data frame containing the variables.
            It is not modified.
        remove_na (bool): Whether to drop rows with missing values in any of
            the used columns before fitting.
        verbose (bool): Whether to print the head of the prepared data and
            the results of the basic model. Default is False.

    Returns:
        dict: A dictionary with the keys

            - ``"estimate"``: DataFrame with one row per model and the columns
              ``variables``, ``estimate``, ``conf_low``, ``conf_high``, ``p``,
              ``aic`` and ``n``. The first row is labelled ``"basic"``; the
              others are labelled with the comma-separated added variables.
            - ``"xlist"``: the ``xlist`` argument.
            - ``"fun"``: ``"all_lm"``.
            - ``"basic"``: the ``basic`` argument.
            - ``"family"``: ``"lm"``.

    Examples:
        >>> from spotpython.utils.stats import fit_all_lm
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'y': [1, 2, 3],
        ...     'x1': [4, 5, 6],
        ...     'x2': [7, 8, 9]
        ... })
        >>> result = fit_all_lm("y ~ x1", ["x2"], data)
        >>> result["estimate"]["variables"].tolist()
        ['basic', 'x2']
    """
    columns = ["variables", "estimate", "conf_low", "conf_high", "p", "aic", "n"]

    def _summarize(label, model) -> list:
        # Position 1 selects the first parameter after the leading one
        # (the intercept for a formula that has an intercept).
        bounds = model.conf_int().iloc[1]
        return [label, model.params.iloc[1], bounds.iloc[0], bounds.iloc[1], model.pvalues.iloc[1], model.aic, len(model.resid)]

    # Column selection returns a new frame, so the caller's data is never touched.
    # Duplicates are removed in case xlist repeats a variable of the basic formula.
    needed = list(dict.fromkeys(get_all_vars_from_formula(basic) + list(xlist)))
    frame = data[needed]
    if remove_na:
        frame = frame.dropna()
    if verbose:
        print(frame.head())

    # Basic model
    rows = [_summarize("basic", ols(basic, data=frame).fit())]
    if verbose:
        _, estimate, conf_low, conf_high, p, aic_value, _ = rows[0]
        print(f"p-values: {p}")
        print(f"estimate: {estimate}")
        print(f"conf_int: ({conf_low}, {conf_high})")
        print(f"aic: {aic_value}")

    # All non-empty combinations of the extra variables; each model is
    # summarized right after fitting so fitted models are not accumulated.
    combinations = itertools.chain.from_iterable(itertools.combinations(xlist, r) for r in range(1, len(xlist) + 1))
    for comb in combinations:
        model = ols(f"{basic} + {' + '.join(comb)}", data=frame).fit()
        rows.append(_summarize(", ".join(comb), model))

    estimates = pd.DataFrame(rows, columns=columns)
    return {"estimate": estimates, "xlist": xlist, "fun": "all_lm", "basic": basic, "family": "lm"}
295+
296+
297+
def plot_coeff_vs_pvals(data, xlabels=None, xlim=(0, 1), xlab="p-value", ylim=None, ylab=None, xscale_log=True, yscale_log=False, title=None, show=True) -> None:
    """Scatter the coefficient estimates of all fitted models against their p-values.

    The p-values are raised to the power ``log(0.5) / log(0.05)`` before
    plotting, which places p = 0.05 at x = 0.5; the tick positions are
    transformed the same way while the tick labels show the original
    p-values. Dashed reference lines mark x = 0.5 and the "no effect"
    level (0 for ``"all_lm"`` results, 1 otherwise).

    Args:
        data (dict):
            Result dictionary as generated by the fit_all_lm function.
            The keys ``"estimate"`` (DataFrame with the columns ``estimate``
            and ``p``) and ``"fun"`` are read. The input is not modified.
        xlabels (list):
            p-values to use as x-axis tick labels.
            Defaults to [0, 0.001, 0.01, 0.05, 0.2, 0.5, 1].
        xlim (tuple):
            The x-axis limits.
        xlab (str):
            The x-axis label.
        ylim (tuple):
            The y-axis limits. If None, they are derived from the largest
            absolute estimate.
        ylab (str):
            The y-axis label. If None, it is chosen from ``data["fun"]``.
        xscale_log (bool):
            Whether to use a log scale on the x-axis.
        yscale_log (bool):
            Whether to use a log scale on the y-axis.
        title (str):
            The plot title.
        show (bool):
            Whether to display the plot.

    Returns:
        None

    Notes:
        * Based on the R package 'allestimates' by Zhiqiang Wang, see https://cran.r-project.org/package=allestimates

    References:
        Wang, Z. (2007). Two Postestimation Commands for Assessing Confounding Effects in Epidemiological Studies. The Stata Journal, 7(2), 183-196. https://doi.org/10.1177/1536867X0700700203

    Examples:
        >>> from spotpython.utils.stats import plot_coeff_vs_pvals, fit_all_lm
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'y': [1, 2, 3],
        ...     'x1': [4, 5, 6],
        ...     'x2': [7, 8, 9]
        ... })
        >>> estimates = fit_all_lm("y ~ x1", ["x2"], data)
        >>> plot_coeff_vs_pvals(estimates)
    """
    # Private copy: the "p_value" column added below must not reach the caller.
    results = copy.deepcopy(data)
    exponent = np.log(0.5) / np.log(0.05)

    tick_labels = [0, 0.001, 0.01, 0.05, 0.2, 0.5, 1] if xlabels is None else xlabels
    tick_positions = np.power(tick_labels, exponent)

    estimates = results["estimate"]
    is_linear_model = results["fun"] == "all_lm"
    if ylab is None:
        if is_linear_model:
            ylab = "Coefficient"
        else:
            ylab = "Effect estimates"
    # Reference level for "no effect": 0 for coefficients, 1 otherwise.
    reference_level = 0 if is_linear_model else 1

    estimates["p_value"] = np.power(estimates["p"], exponent)
    if ylim is None:
        largest = max(estimates["estimate"].max(), abs(estimates["estimate"].min()))
        if is_linear_model:
            ylim = (-largest, largest)
        else:
            ylim = (1 / largest, largest)

    plt.figure(figsize=(10, 6))
    sns.scatterplot(data=estimates, x="p_value", y="estimate")
    # NOTE(review): with the defaults, the x-axis is log-scaled while the
    # lower limit and the first tick are 0 — confirm this is intended.
    if xscale_log:
        plt.xscale("log")
    if yscale_log:
        plt.yscale("log")
    plt.xticks(ticks=tick_positions, labels=tick_labels)
    plt.axvline(x=0.5, linestyle="--")
    plt.axhline(y=reference_level, linestyle="--")
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    if title:
        plt.title(title)
    plt.grid(True)
    if show:
        plt.show()
376+
377+
378+
def plot_coeff_vs_pvals_by_included(data, xlabels=None, xlim=(0, 1), xlab="P value", ylim=None, ylab=None, yscale_log=False, title=None, grid=True, ncol=2, show=True) -> None:
    """
    Generates a panel of scatter plots with effect estimates of all possible models against p-values.
    Uses a dictionary generated by the fit_all_lm function.
    There is one panel per variable in ``xlist``; in each panel the models are
    colored by whether they include that variable.

    The p-values are raised to the power ``log(0.5) / log(0.05)`` before
    plotting, which places p = 0.05 at x = 0.5. The input dictionary and its
    DataFrame are not modified.

    Args:
        data (dict): A dictionary, generated by the fit_all_lm function, containing the following keys:
            - estimate (pd.DataFrame): A DataFrame containing the estimates, with the columns
              ``variables``, ``estimate``, ``conf_low``, ``conf_high``, ``p``, ``aic`` and ``n``.
              ``variables`` holds the comma-separated names of the variables added to a model.
            - xlist (list): A list of variables.
            - fun (str): The function name.
        xlabels (list): A list of x-axis labels. Defaults to [0, 0.001, 0.01, 0.05, 0.2, 0.5, 1].
        xlim (tuple): The x-axis limits.
        xlab (str): The x-axis label.
        ylim (tuple): The y-axis limits. If None, derived from the largest absolute estimate.
        ylab (str): The y-axis label. If None, chosen from ``data["fun"]``.
        yscale_log (bool): Whether to scale y-axis to log10. Default is False.
        title (str): The title of the plot.
        grid (bool): Whether to display gridlines. Default is True.
        ncol (int): Number of columns in the plot grid. Default is 2.
        show (bool): Whether to display the plot. Default is True.

    Returns:
        None

    Notes:
        * Based on the R package 'allestimates' by Zhiqiang Wang, see https://cran.r-project.org/package=allestimates

    References:
        Wang, Z. (2007). Two Postestimation Commands for Assessing Confounding Effects in Epidemiological Studies. The Stata Journal, 7(2), 183-196. https://doi.org/10.1177/1536867X0700700203

    Examples:
        >>> import pandas as pd
        >>> from spotpython.utils.stats import plot_coeff_vs_pvals_by_included
        >>> data = {
        ...     "estimate": pd.DataFrame({
        ...         "variables": ["Crude", "AL", "AM", "AN", "AO"],
        ...         "estimate": [0.5, 0.6, 0.7, 0.8, 0.9],
        ...         "conf_low": [0.1, 0.2, 0.3, 0.4, 0.5],
        ...         "conf_high": [0.9, 1.0, 1.1, 1.2, 1.3],
        ...         "p": [0.01, 0.02, 0.03, 0.04, 0.05],
        ...         "aic": [100, 200, 300, 400, 500],
        ...         "n": [10, 20, 30, 40, 50]
        ...     }),
        ...     "xlist": ["AL", "AM", "AN", "AO"],
        ...     "fun": "all_lm"
        ... }
        >>> plot_coeff_vs_pvals_by_included(data)
    """
    if xlabels is None:
        xlabels = [0, 0.001, 0.01, 0.05, 0.2, 0.5, 1]
    exponent = np.log(0.5) / np.log(0.05)
    xbreaks = np.power(xlabels, exponent)

    fun = data["fun"]
    xlist = list(data["xlist"])
    # Work on a copy so the "p_value" column is not added to the caller's DataFrame.
    result_df = data["estimate"].copy()
    if ylab is None:
        ylab = {"all_lm": "Coefficient", "poisson": "Rate ratio", "binomial": "Odds ratio"}.get(fun, "Effect estimates")

    # Reference level for "no effect": 0 for coefficients, 1 for ratios.
    hline = 0 if fun == "all_lm" else 1

    result_df["p_value"] = np.power(result_df["p"], exponent)
    if ylim is None:
        maxv = max(result_df["estimate"].max(), abs(result_df["estimate"].min()))
        ylim = (-maxv, maxv) if fun == "all_lm" else (1 / maxv, maxv)

    # Mark inclusion by exact name: the model label is a comma-separated list,
    # and a substring test would wrongly report "x1" as part of "x10".
    included = [{name.strip() for name in str(label).split(",")} for label in result_df["variables"]]
    mark_df = pd.DataFrame({x: [int(x in names) for names in included] for x in xlist}, index=result_df.index)
    df_scatter = pd.concat([result_df, mark_df], axis=1)

    # Melt the DataFrame for plotting: one row per (model, variable) pair
    df_long = df_scatter.melt(id_vars=["variables", "estimate", "conf_low", "conf_high", "p", "aic", "n", "p_value"], value_vars=xlist, var_name="variable", value_name="inclusion")
    df_long["inclusion"] = df_long["inclusion"].apply(lambda x: "Included" if x > 0 else "Not included")

    # Plotting
    g = sns.FacetGrid(df_long, col="variable", hue="inclusion", palette={"Included": "blue", "Not included": "orange"}, col_wrap=ncol, height=4, sharex=False, sharey=False)
    g.map(sns.scatterplot, "p_value", "estimate")
    g.add_legend()
    for ax in g.axes.flat:
        ax.set_xticks(xbreaks)
        ax.set_xticklabels(xlabels)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.axvline(x=0.5, linestyle="--", linewidth=1.5, color="black")  # Black dashed vertical line
        ax.axhline(y=hline, linestyle="--", linewidth=1.5, color="black")  # Black dashed horizontal line
        if grid:
            ax.grid(True)
    if yscale_log:
        g.set(yscale="log")
    g.set_axis_labels(xlab, ylab)
    g.set_titles("{col_name}")
    if title:
        # Leave room above the panels for the figure-level title.
        plt.subplots_adjust(top=0.9)
        g.figure.suptitle(title)
    if show:
        plt.show()

test/test_get_all_lm.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
import pandas as pd
3+
from spotpython.utils.stats import fit_all_lm
4+
5+
def test_fit_all_lm():
    """Check the model labels and metadata returned by fit_all_lm.

    Covers a single extra variable, two extra variables (all combinations),
    and input with a missing value that is dropped before fitting.
    """

    def check_result(result, expected_vars):
        # Assertions shared by every scenario below.
        assert list(result['estimate']['variables']) == expected_vars
        assert result['fun'] == 'all_lm'
        assert result['basic'] == 'y ~ x1'
        assert result['family'] == 'lm'

    # Test case 1: Basic model with one independent variable
    one_extra = pd.DataFrame({
        'y': [1, 2, 3],
        'x1': [4, 5, 6],
        'x2': [7, 8, 9]
    })
    check_result(fit_all_lm("y ~ x1", ["x2"], one_extra), ['basic', 'x2'])

    # Test case 2: Model with multiple independent variables
    two_extra = pd.DataFrame({
        'y': [1, 2, 3, 4],
        'x1': [4, 5, 6, 7],
        'x2': [7, 8, 9, 10],
        'x3': [10, 11, 12, 13]
    })
    check_result(fit_all_lm("y ~ x1", ["x2", "x3"], two_extra), ['basic', 'x2', 'x3', 'x2, x3'])

    # Test case 3: Model with missing values
    with_missing = pd.DataFrame({
        'y': [1, 2, None, 4],
        'x1': [4, 5, 6, 7],
        'x2': [7, 8, 9, 10]
    })
    cleaned = fit_all_lm("y ~ x1", ["x2"], with_missing, remove_na=True)
    check_result(cleaned, ['basic', 'x2'])
    # One of the four rows has a missing y, so three observations remain.
    assert cleaned['estimate']['n'].iloc[0] == 3


if __name__ == "__main__":
    pytest.main()

test/test_get_vars_from_formula.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from spotpython.utils.stats import get_all_vars_from_formula
3+
4+
def test_get_all_vars_from_formula():
    """Check get_all_vars_from_formula against a table of formulas."""
    cases = [
        # Simple formula
        ("y ~ x1 + x2", ['y', 'x1', 'x2']),
        # Formula with extra spaces
        (" y ~ x1 + x2 ", ['y', 'x1', 'x2']),
        # Formula with multiple independent variables
        ("y ~ x1 + x2 + x3 + x4", ['y', 'x1', 'x2', 'x3', 'x4']),
        # Formula with no independent variables
        ("y ~ ", ['y']),
        # Formula with only one independent variable
        ("y ~ x1", ['y', 'x1']),
    ]
    for formula, expected_vars in cases:
        assert get_all_vars_from_formula(formula) == expected_vars


if __name__ == "__main__":
    pytest.main()

0 commit comments

Comments
 (0)