Skip to content

Commit 0e080b4

Browse files
v0.0.18 improved importance plots
1 parent 6751b28 commit 0e080b4

2 files changed

Lines changed: 30 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.0.17"
10+
version = "0.0.18"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -606,16 +606,33 @@ def plot_contour(self, i=0, j=1, min_z=None, max_z=None, show=True):
606606
#
607607
pylab.show()
608608

609-
def print_importance(self):
610-
if self.surrogate.n_theta > 1:
611-
theta = np.power(10, self.surrogate.theta)
612-
print("Importance relative to the most important parameter:")
613-
imp = 100 * theta / np.max(theta)
614-
if self.var_name is None:
615-
for i in range(len(imp)):
616-
print("x", i, ": ", imp[i])
617-
else:
618-
for i in range(len(imp)):
619-
print(self.var_name[i] + ": ", imp[i])
609+
610+
def print_importance(self, threshold=0.1, filename=None) -> None:
611+
"""Print importance of each parameter and plot it.
612+
Args:
613+
threshold (float): Only parameters with importance >= threshold are printed.
614+
filename (str): If not None, the plot is saved to the file.
615+
Returns:
616+
None
617+
"""
618+
if self.surrogate.n_theta > 1:
619+
theta = np.power(10, self.surrogate.theta)
620+
print("Importance relative to the most important parameter:")
621+
imp = 100 * theta / np.max(theta)
622+
imp = imp[imp >= threshold]
623+
if self.var_name is None:
624+
for i in range(len(imp)):
625+
print("x", i, ": ", imp[i])
626+
plt.bar(range(len(imp)), imp)
627+
plt.xticks(range(len(imp)), ["x" + str(i) for i in range(len(imp))])
620628
else:
621-
print("Importantance requires more than one theta values (n_theta>1).")
629+
var_name = [self.var_name[i] for i in range(len(imp)) if imp[i] >= threshold]
630+
for i in range(len(imp)):
631+
print(var_name[i] + ": ", imp[i])
632+
plt.bar(range(len(imp)), imp)
633+
plt.xticks(range(len(imp)), var_name)
634+
if filename is not None:
635+
plt.savefig(filename)
636+
plt.show()
637+
else:
638+
print("Importantance requires more than one theta values (n_theta>1).")

0 commit comments

Comments
 (0)