Skip to content

Commit cfcf579

Browse files
0.30.2
1 parent d94f2a6 commit cfcf579

2 files changed

Lines changed: 12 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.30.1"
10+
version = "0.30.2"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/spot/spot.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2625,7 +2625,7 @@ def print_importance(self, threshold=0.1, print_screen=True) -> list:
26252625
print("Importance requires more than one theta values (n_theta>1).")
26262626
return output
26272627

2628-
def plot_importance(self, threshold=0.1, filename=None, dpi=300, show=True, tkagg=False) -> None:
2628+
def plot_importance(self, threshold=0.1, filename=None, dpi=300, show=True, tkagg=False, figsize=(9, 6)) -> None:
26292629
"""Plot the importance of each variable.
26302630
26312631
Args:
@@ -2637,23 +2637,33 @@ def plot_importance(self, threshold=0.1, filename=None, dpi=300, show=True, tkag
26372637
The dpi of the plot.
26382638
show (bool):
26392639
Show the plot. Default is `True`.
2640+
tkagg (bool):
2641+
Use TkAgg backend. Default is `False`.
2642+
figsize (tuple):
2643+
Figure size (width, height) in inches. Default is (9, 6).
26402644
26412645
Returns:
26422646
None
26432647
"""
26442648
if self.surrogate.n_theta > 1:
26452649
if tkagg:
26462650
matplotlib.use("TkAgg")
2651+
2652+
# Create figure with specified size
2653+
plt.figure(figsize=figsize)
2654+
26472655
theta = np.power(10, self.surrogate.theta)
26482656
imp = 100 * theta / np.max(theta)
26492657
idx = np.where(imp > threshold)[0]
2658+
26502659
if self.var_name is None:
26512660
plt.bar(range(len(imp[idx])), imp[idx])
26522661
plt.xticks(range(len(imp[idx])), ["x" + str(i) for i in idx])
26532662
else:
26542663
var_name = [self.var_name[i] for i in idx]
26552664
plt.bar(range(len(imp[idx])), imp[idx])
26562665
plt.xticks(range(len(imp[idx])), var_name)
2666+
26572667
if filename is not None:
26582668
plt.savefig(filename, bbox_inches="tight", dpi=dpi)
26592669
if show:

0 commit comments

Comments
 (0)