Skip to content

Commit 6eef37c

Browse files
0.27.12 importance plots
1 parent 576f225 commit 6eef37c

5 files changed

Lines changed: 316 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.27.11"
10+
version = "0.27.12"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/importance.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import pandas as pd
2+
from sklearn.ensemble import RandomForestRegressor
3+
from sklearn.inspection import permutation_importance
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
7+
8+
def generate_mdi(X, y, feature_names=None, random_state=42) -> pd.DataFrame:
9+
"""
10+
Generates a DataFrame with Gini importances from a RandomForestRegressor.
11+
12+
Notes:
13+
There are two limitations of impurity-based feature importances:
14+
- impurity-based importances are biased towards high cardinality features;
15+
- impurity-based importances are computed on training set statistics
16+
and therefore do not reflect the ability of feature to be useful to
17+
make predictions that generalize to the test set. Permutation
18+
importances can mitigate the last limitation, because ti can be computed on the
19+
test set.
20+
21+
Args:
22+
X (pd.DataFrame or np.ndarray): The feature set.
23+
y (pd.Series or np.ndarray): The target variable.
24+
feature_names (list, optional): List of feature names for labeling. Defaults to None.
25+
random_state (int, optional): Random state for the RandomForestRegressor. Defaults to 42.
26+
27+
Returns:
28+
pd.DataFrame: DataFrame with 'Feature' and 'Importance' columns.
29+
30+
Examples:
31+
>>> from spotpython.plot.importance import generate_mdi
32+
>>> import pandas as pd
33+
>>> from sklearn.datasets import make_regression
34+
>>> X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
35+
>>> X_df = pd.DataFrame(X)
36+
>>> y_series = pd.Series(y)
37+
>>> result = generate_mdi(X_df, y_series)
38+
>>> print(result)
39+
40+
"""
41+
# Convert X and y to pandas DataFrames if they are not already
42+
if not isinstance(X, pd.DataFrame):
43+
X = pd.DataFrame(X)
44+
if not isinstance(y, pd.Series):
45+
y = pd.Series(np.ravel(y)) # Use np.ravel instead of flatten
46+
47+
# Train a Random Forest Regressor
48+
rf = RandomForestRegressor(random_state=random_state)
49+
rf.fit(X, y)
50+
51+
# Get feature importances
52+
importances = rf.feature_importances_
53+
54+
# Create a DataFrame
55+
if feature_names is None:
56+
df_mdi = pd.DataFrame({"Feature": X.columns, "Importance": importances})
57+
else:
58+
df_mdi = pd.DataFrame({"Feature": feature_names, "Importance": importances})
59+
df_mdi = df_mdi.sort_values("Importance", ascending=False).reset_index(drop=True)
60+
61+
return df_mdi
62+
63+
64+
def generate_imp(X_train, X_test, y_train, y_test, random_state=42, n_repeats=10, use_test=True) -> permutation_importance:
65+
"""
66+
Generates permutation importances from a RandomForestRegressor.
67+
68+
Args:
69+
X_train (pd.DataFrame or np.ndarray): The training feature set.
70+
X_test (pd.DataFrame or np.ndarray): The test feature set.
71+
y_train (pd.Series or np.ndarray): The training target variable.
72+
y_test (pd.Series or np.ndarray): The test target variable.
73+
random_state (int, optional): Random state for the RandomForestRegressor. Defaults to 42.
74+
n_repeats (int, optional): Number of repeats for permutation importance. Defaults to 10.
75+
use_test (bool, optional): If True, computes permutation importance on the test set. If False, uses the training set. Defaults to True.
76+
77+
Returns:
78+
permutation_importance: Permutation importances object.
79+
80+
Examples:
81+
>>> from spotpython.plot.importance import generate_imp
82+
>>> import pandas as pd
83+
>>> from sklearn.datasets import make_regression
84+
>>> X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
85+
>>> X_train, X_test = X[:80], X[80:]
86+
>>> y_train, y_test = y[:80], y[80:]
87+
>>> X_train_df = pd.DataFrame(X_train)
88+
>>> X_test_df = pd.DataFrame(X_test)
89+
>>> y_train_series = pd.Series(y_train)
90+
>>> y_test_series = pd.Series(y_test)
91+
>>> perm_imp = generate_imp(X_train_df, X_test_df, y_train_series, y_test_series)
92+
>>> print(perm_imp)
93+
"""
94+
# Convert inputs to pandas DataFrames/Series if they are not already
95+
if not isinstance(X_train, pd.DataFrame):
96+
X_train = pd.DataFrame(X_train)
97+
if not isinstance(X_test, pd.DataFrame):
98+
X_test = pd.DataFrame(X_test)
99+
if not isinstance(y_train, pd.Series):
100+
y_train = pd.Series(np.ravel(y_train)) # Use np.ravel instead of flatten
101+
if not isinstance(y_test, pd.Series):
102+
y_test = pd.Series(np.ravel(y_test)) # Use np.ravel instead of flatten
103+
104+
# Train a Random Forest Regressor
105+
rf = RandomForestRegressor(random_state=random_state)
106+
rf.fit(X_train, y_train)
107+
108+
# Select the dataset for permutation importance
109+
X_eval = X_test if use_test else X_train
110+
y_eval = y_test if use_test else y_train
111+
112+
# Calculate permutation importances
113+
perm_imp = permutation_importance(rf, X_eval, y_eval, n_repeats=n_repeats, random_state=random_state)
114+
115+
return perm_imp
116+
117+
118+
def plot_importances(df_mdi, perm_imp, X_test, target_name=None, feature_names=None, k=10, show=True) -> None:
119+
"""
120+
Plots the impurity-based and permutation-based feature importances for a given classifier.
121+
122+
Args:
123+
df_mdi (pd.DataFrame):
124+
DataFrame with Gini importances.
125+
perm_imp (object):
126+
Permutation importances object.
127+
X_test (pd.DataFrame):
128+
The test feature set for permutation importance.
129+
target_name (str, optional):
130+
Name of the target variable for labeling. Defaults to None.
131+
feature_names (list, optional):
132+
List of feature names for labeling. Defaults to None.
133+
k (int, optional):
134+
Number of top features to display based on importance. Default is 10.
135+
show (bool, optional):
136+
If True, displays the plot immediately. Default is True.
137+
138+
Returns:
139+
None
140+
141+
Examples:
142+
>>> from spotpython.plot.importance import generate_mdi, generate_imp, plot_importances
143+
>>> import pandas as pd
144+
>>> from sklearn.datasets import make_regression
145+
>>> X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
146+
>>> X_train, X_test = X[:80], X[80:]
147+
>>> y_train, y_test = y[:80], y[80:]
148+
>>> X_train_df = pd.DataFrame(X_train)
149+
>>> X_test_df = pd.DataFrame(X_test)
150+
>>> y_train_series = pd.Series(y_train)
151+
>>> y_test_series = pd.Series(y_test)
152+
>>> df_mdi = generate_mdi(X_train_df, y_train_series)
153+
>>> perm_imp = generate_imp(X_train_df, X_test_df, y_train_series, y_test_series)
154+
>>> plot_importances(df_mdi, perm_imp, X_test_df)
155+
156+
"""
157+
158+
# Plot impurity-based importances for top-k features
159+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
160+
161+
sorted_mdi_importances = df_mdi.set_index("Feature")["Importance"]
162+
sorted_mdi_importances[:k].sort_values().plot.barh(ax=ax1)
163+
ax1.set_xlabel("Gini importance")
164+
if target_name:
165+
ax1.set_title(f"Impurity-based feature importances for target: {target_name}")
166+
else:
167+
ax1.set_title("Impurity-based feature importances")
168+
169+
# Ensure X_test is a DataFrame
170+
if not isinstance(X_test, pd.DataFrame):
171+
X_test = pd.DataFrame(X_test)
172+
173+
perm_sorted_idx = perm_imp.importances_mean.argsort()[-k:]
174+
if feature_names is not None:
175+
ax2.boxplot(perm_imp.importances[perm_sorted_idx].T, vert=False, labels=np.array(feature_names)[perm_sorted_idx])
176+
else:
177+
ax2.boxplot(perm_imp.importances[perm_sorted_idx].T, vert=False, labels=X_test.columns[perm_sorted_idx])
178+
ax2.axvline(x=0, color="k", linestyle="--")
179+
if target_name:
180+
ax2.set_xlabel(f"Decrease in mse for target: {target_name}")
181+
else:
182+
ax2.set_xlabel("Decrease in mse")
183+
ax2.set_title("Permutation-based feature importances")
184+
185+
# fig.suptitle("Impurity-based vs. permutation importances")
186+
fig.tight_layout()
187+
if show:
188+
plt.show()

test/test_importance_imp.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
from sklearn.datasets import make_regression
5+
from sklearn.utils import Bunch
6+
from spotpython.plot.importance import generate_imp
7+
8+
def test_generate_imp():
9+
# Generate synthetic regression data
10+
X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
11+
X_train, X_test = X[:80], X[80:]
12+
y_train, y_test = y[:80], y[80:]
13+
14+
# Convert to DataFrame/Series for testing
15+
feature_names = [f"Feature_{i}" for i in range(X.shape[1])]
16+
X_train_df = pd.DataFrame(X_train, columns=feature_names)
17+
X_test_df = pd.DataFrame(X_test, columns=feature_names)
18+
y_train_series = pd.Series(y_train)
19+
y_test_series = pd.Series(y_test)
20+
21+
# Test permutation importance on the test set (default behavior)
22+
perm_imp_test = generate_imp(X_train_df, X_test_df, y_train_series, y_test_series, use_test=True)
23+
assert isinstance(perm_imp_test, Bunch), "Output should be a Bunch object"
24+
assert perm_imp_test.importances_mean.shape[0] == X.shape[1], "Number of importances should match the number of features"
25+
assert np.all(perm_imp_test.importances_mean >= 0), "All importances should be non-negative"
26+
27+
# Test permutation importance on the training set
28+
perm_imp_train = generate_imp(X_train_df, X_test_df, y_train_series, y_test_series, use_test=False)
29+
assert isinstance(perm_imp_train, Bunch), "Output should be a Bunch object"
30+
assert perm_imp_train.importances_mean.shape[0] == X.shape[1], "Number of importances should match the number of features"
31+
assert np.all(perm_imp_train.importances_mean >= 0), "All importances should be non-negative"

test/test_importance_mdi.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
from sklearn.datasets import make_regression
5+
from spotpython.plot.importance import generate_mdi
6+
7+
def test_generate_mdi_with_dataframe():
8+
# Generate synthetic data
9+
X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
10+
feature_names = [f"Feature_{i}" for i in range(X.shape[1])]
11+
X_df = pd.DataFrame(X, columns=feature_names)
12+
y_series = pd.Series(y)
13+
14+
# Call the function
15+
result = generate_mdi(X_df, y_series)
16+
17+
# Assertions
18+
assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"
19+
assert list(result.columns) == ["Feature", "Importance"], "DataFrame should have 'Feature' and 'Importance' columns"
20+
assert len(result) == X_df.shape[1], "Number of rows should match the number of features"
21+
assert result["Importance"].sum() > 0, "Feature importances should be greater than zero"
22+
23+
def test_generate_mdi_with_ndarray():
24+
# Generate synthetic data
25+
X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
26+
27+
# Call the function
28+
result = generate_mdi(X, y)
29+
30+
# Assertions
31+
assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"
32+
assert list(result.columns) == ["Feature", "Importance"], "DataFrame should have 'Feature' and 'Importance' columns"
33+
assert len(result) == X.shape[1], "Number of rows should match the number of features"
34+
assert result["Importance"].sum() > 0, "Feature importances should be greater than zero"
35+
36+
def test_generate_mdi_with_custom_feature_names():
37+
# Generate synthetic data
38+
X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42)
39+
feature_names = [f"Custom_Feature_{i}" for i in range(X.shape[1])]
40+
X_df = pd.DataFrame(X)
41+
42+
# Call the function
43+
result = generate_mdi(X_df, y, feature_names=feature_names)
44+
45+
# Assertions
46+
assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"
47+
assert list(result.columns) == ["Feature", "Importance"], "DataFrame should have 'Feature' and 'Importance' columns"
48+
assert len(result) == len(feature_names), "Number of rows should match the number of custom feature names"
49+
assert set(result["Feature"].values) == set(feature_names), "Feature names should match the custom feature names"
50+
assert result["Importance"].sum() > 0, "Feature importances should be greater than zero"

test/test_importance_plot.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
from sklearn.ensemble import RandomForestRegressor
5+
from sklearn.inspection import permutation_importance
6+
from spotpython.plot.importance import plot_importances
7+
8+
@pytest.fixture
9+
def sample_data():
10+
# Generate sample data
11+
np.random.seed(42)
12+
X_train = pd.DataFrame(np.random.rand(100, 5), columns=[f"Feature_{i}" for i in range(5)])
13+
X_test = pd.DataFrame(np.random.rand(20, 5), columns=[f"Feature_{i}" for i in range(5)])
14+
y_train = pd.Series(np.random.rand(100))
15+
y_test = pd.Series(np.random.rand(20))
16+
return X_train, X_test, y_train, y_test
17+
18+
@pytest.fixture
19+
def mdi_importances(sample_data):
20+
# Generate MDI importances
21+
X_train, _, y_train, _ = sample_data
22+
rf = RandomForestRegressor(random_state=42)
23+
rf.fit(X_train, y_train)
24+
importances = rf.feature_importances_
25+
df_mdi = pd.DataFrame({"Feature": X_train.columns, "Importance": importances}).sort_values("Importance", ascending=False)
26+
return df_mdi
27+
28+
@pytest.fixture
29+
def perm_importances(sample_data):
30+
# Generate permutation importances
31+
X_train, X_test, y_train, y_test = sample_data
32+
rf = RandomForestRegressor(random_state=42)
33+
rf.fit(X_train, y_train)
34+
perm_imp = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42)
35+
return perm_imp
36+
37+
def test_plot_importances(sample_data, mdi_importances, perm_importances):
38+
X_train, X_test, y_train, y_test = sample_data
39+
df_mdi = mdi_importances
40+
perm_imp = perm_importances
41+
42+
# Test if the function runs without errors
43+
try:
44+
plot_importances(df_mdi, perm_imp, X_test, target_name="Test Target", feature_names=X_train.columns, k=3, show=False)
45+
except Exception as e:
46+
pytest.fail(f"plot_importances raised an exception: {e}")

0 commit comments

Comments
 (0)