Skip to content

Commit 941cd41

Browse files
v0.6.26
show options added
1 parent 2c58e68 commit 941cd41

2 files changed

Lines changed: 17 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.25"
10+
version = "0.6.26"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def plot_contour(
849849
if show:
850850
pylab.show()
851851

852-
def plot_important_hyperparameter_contour(self, threshold=0.025, filename=None):
852+
def plot_important_hyperparameter_contour(self, threshold=0.025, filename=None, show=True) -> None:
853853
impo = self.print_importance(threshold=threshold, print_screen=True)
854854
var_plots = [i for i, x in enumerate(impo) if x[1] > threshold]
855855
min_z = min(self.y)
@@ -861,7 +861,7 @@ def plot_important_hyperparameter_contour(self, threshold=0.025, filename=None):
861861
filename_full = filename + "_contour_" + str(i) + "_" + str(j) + ".png"
862862
else:
863863
filename_full = None
864-
self.plot_contour(i=i, j=j, min_z=min_z, max_z=max_z, filename=filename_full)
864+
self.plot_contour(i=i, j=j, min_z=min_z, max_z=max_z, filename=filename_full, show=show)
865865

866866
def get_importance(self) -> list:
867867
"""Get importance of each variable and return the results as a list.
@@ -934,7 +934,17 @@ def plot_importance(self, threshold=0.1, filename=None, dpi=300) -> None:
934934
plt.savefig(filename, bbox_inches="tight", dpi=dpi)
935935
plt.show()
936936

937-
def parallel_plot(self):
937+
def parallel_plot(self, show=True) -> go.Figure:
938+
"""
939+
Parallel plot.
940+
941+
Args:
942+
show (bool): show the plot
943+
944+
Returns:
945+
fig (plotly.graph_objects.Figure): figure object
946+
947+
"""
938948
X = self.X
939949
y = self.y
940950
df = pd.DataFrame(np.concatenate((X, y.reshape(-1, 1)), axis=1), columns=self.var_name + ["y"])
@@ -950,4 +960,6 @@ def parallel_plot(self):
950960
),
951961
)
952962
)
953-
fig.show()
963+
if show:
964+
fig.show()
965+
return fig

0 commit comments

Comments
 (0)