Skip to content

Commit 36f966a

Browse files
0.13.2
Updated plot_important_hyperparameter_contour function
1 parent 0151585 commit 36f966a

4 files changed

Lines changed: 101 additions & 43 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 42 additions & 23 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.13.1"
10+
version = "0.13.2"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import logging
2929
import time
3030
from spotPython.utils.progress import progress_bar
31-
from spotPython.utils.convert import find_indices
31+
from spotPython.utils.convert import find_indices, sort_by_kth_and_return_indices
3232
from spotPython.hyperparameters.values import (
3333
get_control_key_value,
3434
get_ith_hyperparameter_name_from_fun_control,
@@ -1640,7 +1640,7 @@ def plot_contour(
16401640
pylab.show()
16411641

16421642
def plot_important_hyperparameter_contour(
1643-
self, threshold=0.0, filename=None, show=True, max_imp=None, title=""
1643+
self, threshold=0.0, filename=None, show=True, max_imp=None, title="", scale_global=False
16441644
) -> None:
16451645
"""
16461646
Plot the contour of important hyperparameters.
@@ -1649,7 +1649,7 @@ def plot_important_hyperparameter_contour(
16491649
16501650
Args:
16511651
threshold (float):
1652-
threshold for the importance
1652+
threshold for the importance. Not used any more in spotPython >= 0.13.2.
16531653
filename (str):
16541654
filename of the plot
16551655
show (bool):
@@ -1691,23 +1691,22 @@ def plot_important_hyperparameter_contour(
16911691
S.plot_important_hyperparameter_contour()
16921692
16931693
"""
1694-
impo_org = self.print_importance(threshold=threshold, print_screen=True)
1695-
print(f"impo: {impo_org}")
1696-
try:
1697-
impo = sorted(impo_org, key=lambda x: x[1], reverse=True)
1698-
except ValueError as e:
1699-
print(f"ValueError: {e}")
1700-
impo = impo_org
1701-
# if there are more than imp_max variables, select only the most important ones:
1694+
impo = self.print_importance(threshold=threshold, print_screen=True)
1695+
print(f"impo: {impo}")
1696+
indices = sort_by_kth_and_return_indices(array=impo, k=1)
1697+
print(f"indices: {indices}")
1698+
# take the first max_imp values from the indices array
17021699
if max_imp is not None:
1703-
if len(impo) > max_imp:
1704-
impo = impo[:max_imp]
1705-
print(f"impo after select: {impo}")
1706-
var_plots = [i for i, x in enumerate(impo) if x[1] > threshold]
1707-
min_z = min(self.y)
1708-
max_z = max(self.y)
1709-
for i in var_plots:
1710-
for j in var_plots:
1700+
indices = indices[:max_imp]
1701+
print(f"indices after max_imp selection: {indices}")
1702+
if scale_global:
1703+
min_z = min(self.y)
1704+
max_z = max(self.y)
1705+
else:
1706+
min_z = None
1707+
max_z = None
1708+
for i in indices:
1709+
for j in indices:
17111710
if j > i:
17121711
if filename is not None:
17131712
filename_full = filename + "_contour_" + str(i) + "_" + str(j) + ".png"

src/spotPython/utils/convert.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,43 @@ def map_to_True_False(value):
139139
return True
140140
else:
141141
return False
142+
143+
144+
def sort_by_kth_and_return_indices(array, k):
145+
"""Sorts an array of arrays based on the k-th values in descending order and returns
146+
the indices of the original array entries.
147+
148+
Args:
149+
array (list of lists): The array to be sorted. Each sub-array should have at least
150+
`k+1` elements.
151+
k (int): The index (zero-based) of the element within each sub-array to sort by.
152+
153+
Returns:
154+
list of int: Indices of the original array entries after sorting by the k-th value.
155+
156+
Raises:
157+
ValueError: If the input array is empty, None, or any sub-array does not have at least
158+
`k+1` elements, or if k is out of bounds for any sub-array.
159+
160+
Examples:
161+
>>> from spotPython.utils.convert import sort_by_kth_and_return_indices
162+
try:
163+
array = [['x0', 85.50983192204619], ['x1', 100.0], ['x2', 81.35712613549178]]
164+
k = 1 # Sort by the second element in each sub-array
165+
indices = sort_by_kth_and_return_indices(array, k)
166+
print("Indices of the sorted elements using the k-th value:", indices)
167+
except ValueError as error:
168+
print(f"Sorting failed due to: {error}")
169+
"""
170+
if not array:
171+
return []
172+
173+
# Check for improperly structured sub-arrays and that k is within bounds
174+
for item in array:
175+
if not isinstance(item, list) or len(item) <= k:
176+
raise ValueError("All sub-arrays must be lists with at least k+1 elements.")
177+
178+
# Enumerate the array to keep track of original indices, then sort by the k-th item
179+
sorted_indices = [index for index, value in sorted(enumerate(array), key=lambda x: x[1][k], reverse=True)]
180+
181+
return sorted_indices

0 commit comments

Comments
 (0)