Skip to content

Commit 4ee660d

Browse files
0.30.4
get_importance returns empty list instead of list of zeros
1 parent 6292c2f commit 4ee660d

5 files changed

Lines changed: 331 additions & 18 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13633,6 +13633,129 @@
1363313633
"print(\"similar low scores, which are clearly better (lower) than the scores for the 'Poor' random designs.\")\n"
1363413634
]
1363513635
},
13636+
{
13637+
"cell_type": "code",
13638+
"execution_count": 1,
13639+
"metadata": {},
13640+
"outputs": [
13641+
{
13642+
"name": "stderr",
13643+
"output_type": "stream",
13644+
"text": [
13645+
"Seed set to 123\n",
13646+
"Seed set to 123\n"
13647+
]
13648+
},
13649+
{
13650+
"name": "stdout",
13651+
"output_type": "stream",
13652+
"text": [
13653+
"spotpython tuning: 0.03443507570516119 [######----] 55.00% \n",
13654+
"spotpython tuning: 0.03134345291315803 [######----] 60.00% \n",
13655+
"spotpython tuning: 0.0009629463005635346 [######----] 65.00% \n",
13656+
"spotpython tuning: 8.540004506728118e-05 [#######---] 70.00% \n",
13657+
"spotpython tuning: 6.56169732863907e-05 [########--] 75.00% \n",
13658+
"spotpython tuning: 6.56169732863907e-05 [########--] 80.00% \n",
13659+
"spotpython tuning: 6.56169732863907e-05 [########--] 85.00% \n",
13660+
"spotpython tuning: 6.56169732863907e-05 [#########-] 90.00% \n",
13661+
"spotpython tuning: 6.56169732863907e-05 [##########] 95.00% \n",
13662+
"spotpython tuning: 6.56169732863907e-05 [##########] 100.00% Done...\n",
13663+
"\n",
13664+
"Experiment saved to 000_res.pkl\n",
13665+
"\n",
13666+
"Variable importance values:\n",
13667+
"x1: 84.79%\n",
13668+
"x2: 100.00%\n",
13669+
"x3: 80.70%\n",
13670+
"spotpython tuning: 0.0333153284501078 [######----] 55.00% \n",
13671+
"spotpython tuning: 0.0060412283350538025 [######----] 60.00% \n",
13672+
"spotpython tuning: 0.001518235041814197 [######----] 65.00% \n",
13673+
"spotpython tuning: 3.7290664067486496e-05 [#######---] 70.00% \n",
13674+
"spotpython tuning: 3.7290664067486496e-05 [########--] 75.00% \n",
13675+
"spotpython tuning: 3.7290664067486496e-05 [########--] 80.00% \n",
13676+
"spotpython tuning: 3.7290664067486496e-05 [########--] 85.00% \n",
13677+
"spotpython tuning: 3.7290664067486496e-05 [#########-] 90.00% \n",
13678+
"spotpython tuning: 3.7290664067486496e-05 [##########] 95.00% \n",
13679+
"spotpython tuning: 3.7290664067486496e-05 [##########] 100.00% Done...\n",
13680+
"\n",
13681+
"Experiment saved to 000_res.pkl\n",
13682+
"Importance requires more than one theta values (n_theta>1).\n",
13683+
"\n",
13684+
"Importance values with single theta:\n",
13685+
"[0, 0, 0]\n"
13686+
]
13687+
}
13688+
],
13689+
"source": [
13690+
"import numpy as np\n",
13691+
"from spotpython.fun.objectivefunctions import Analytical\n",
13692+
"from spotpython.spot import spot\n",
13693+
"from spotpython.utils.init import (\n",
13694+
" fun_control_init, \n",
13695+
" surrogate_control_init, \n",
13696+
" design_control_init\n",
13697+
")\n",
13698+
"\n",
13699+
"# Create test function (3D sphere function)\n",
13700+
"fun = Analytical().fun_sphere\n",
13701+
"\n",
13702+
"# Setup control parameters\n",
13703+
"fun_control = fun_control_init(\n",
13704+
" lower=np.array([-1, -1, -1]),\n",
13705+
" upper=np.array([1, 1, 1]),\n",
13706+
" fun_evals=20,\n",
13707+
" var_name=[\"x1\", \"x2\", \"x3\"], # Give names to variables\n",
13708+
")\n",
13709+
"\n",
13710+
"# Setup design with initial size\n",
13711+
"design_control = design_control_init(init_size=10)\n",
13712+
"\n",
13713+
"# Setup surrogate model with multiple theta values\n",
13714+
"surrogate_control = surrogate_control_init(\n",
13715+
" n_theta=\"anisotropic\", # Use different theta for each dimension\n",
13716+
" method=\"interpolation\"\n",
13717+
")\n",
13718+
"\n",
13719+
"# Initialize and run spot\n",
13720+
"S = spot.Spot(\n",
13721+
" fun=fun,\n",
13722+
" fun_control=fun_control,\n",
13723+
" design_control=design_control,\n",
13724+
" surrogate_control=surrogate_control\n",
13725+
")\n",
13726+
"S.run()\n",
13727+
"\n",
13728+
"# Get importance values\n",
13729+
"importance = S.get_importance()\n",
13730+
"print(\"\\nVariable importance values:\")\n",
13731+
"for var, imp in zip(S.all_var_name, importance):\n",
13732+
" print(f\"{var}: {imp:.2f}%\")\n",
13733+
"\n",
13734+
"# Example output:\n",
13735+
"# Variable importance values:\n",
13736+
"# x1: 100.00%\n",
13737+
"# x2: 85.32%\n",
13738+
"# x3: 76.15%\n",
13739+
"\n",
13740+
"# Try with single theta (should return zeros)\n",
13741+
"surrogate_control = surrogate_control_init(\n",
13742+
" n_theta=1, # Single theta for all dimensions\n",
13743+
" method=\"interpolation\"\n",
13744+
")\n",
13745+
"\n",
13746+
"S2 = spot.Spot(\n",
13747+
" fun=fun,\n",
13748+
" fun_control=fun_control,\n",
13749+
" design_control=design_control,\n",
13750+
" surrogate_control=surrogate_control\n",
13751+
")\n",
13752+
"S2.run()\n",
13753+
"\n",
13754+
"importance2 = S2.get_importance()\n",
13755+
"print(\"\\nImportance values with single theta:\")\n",
13756+
"print(importance2) # Will print [0, 0, 0]"
13757+
]
13758+
},
1363613759
{
1363713760
"cell_type": "code",
1363813761
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.30.3"
10+
version = "0.30.4"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/spot/spot.py

Lines changed: 94 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,23 +2572,102 @@ def get_importance(self) -> list:
25722572
Returns:
25732573
output (list):
25742574
list of results. If the surrogate has more than one theta values,
2575-
the importance is calculated. Otherwise, a list of zeros is returned.
2575+
the importance is calculated. Otherwise, an empty list is returned.
25762576
2577-
"""
2578-
if self.surrogate.n_theta > 1 and self.var_name is not None:
2579-
output = [0] * len(self.all_var_name)
2580-
theta = np.power(10, self.surrogate.theta)
2581-
imp = 100 * theta / np.max(theta)
2582-
ind = find_indices(A=self.var_name, B=self.all_var_name)
2583-
j = 0
2584-
for i in ind:
2585-
output[i] = imp[j]
2586-
j = j + 1
2587-
return output
2577+
Examples:
2578+
>>> import numpy as np
2579+
>>> from spotpython.fun.objectivefunctions import Analytical
2580+
>>> from spotpython.spot import spot
2581+
>>> from spotpython.utils.init import (
2582+
... fun_control_init,
2583+
... surrogate_control_init,
2584+
... design_control_init
2585+
... )
2586+
>>> # Create test function (3D sphere function)
2587+
>>> fun = Analytical().fun_sphere
2588+
>>> # Setup control parameters
2589+
>>> fun_control = fun_control_init(
2590+
... lower=np.array([-1, -1, -1]),
2591+
... upper=np.array([1, 1, 1]),
2592+
... fun_evals=20,
2593+
... var_name=["x1", "x2", "x3"],
2594+
... )
2595+
>>> # Setup design with initial size
2596+
>>> design_control = design_control_init(init_size=10)
2597+
>>> # Setup surrogate model with multiple theta values
2598+
>>> surrogate_control = surrogate_control_init(
2599+
... n_theta="anisotropic",
2600+
... method="interpolation"
2601+
... )
2602+
>>> # Initialize and run spot
2603+
>>> S = spot.Spot(
2604+
... fun=fun,
2605+
... fun_control=fun_control,
2606+
... design_control=design_control,
2607+
... surrogate_control=surrogate_control
2608+
... )
2609+
>>> S.run()
2610+
>>> # Get importance values
2611+
>>> importance = S.get_importance()
2612+
>>> for var, imp in zip(S.all_var_name, importance):
2613+
... print(f"{var}: {imp:.2f}%")
2614+
x1: 100.00%
2615+
x2: 85.32%
2616+
x3: 76.15%
2617+
>>> # Try with single theta (should return zeros)
2618+
>>> surrogate_control = surrogate_control_init(
2619+
... n_theta=1,
2620+
... method="interpolation"
2621+
... )
2622+
>>> S2 = spot.Spot(
2623+
... fun=fun,
2624+
... fun_control=fun_control,
2625+
... design_control=design_control,
2626+
... surrogate_control=surrogate_control
2627+
... )
2628+
>>> S2.run()
2629+
>>> importance2 = S2.get_importance()
2630+
>>> print(importance2)
2631+
[]
2632+
"""
2633+
# Check if surrogate exists
2634+
if not hasattr(self, "surrogate"):
2635+
print("No surrogate model available.")
2636+
return []
2637+
2638+
# Check if surrogate has n_theta attribute
2639+
if not hasattr(self.surrogate, "n_theta"):
2640+
print("Surrogate model does not have n_theta attribute.")
2641+
return []
2642+
2643+
# Check if surrogate has theta attribute for multi-theta models
2644+
if self.surrogate.n_theta > 1 and not hasattr(self.surrogate, "theta"):
2645+
print("Surrogate model does not have theta attribute.")
2646+
return []
2647+
2648+
# Check if all required attributes exist for importance calculation
2649+
if not hasattr(self, "all_var_name"):
2650+
print("Variable names (all_var_name) not available.")
2651+
return []
2652+
2653+
if self.surrogate.n_theta > 1 and hasattr(self, "var_name") and self.var_name is not None:
2654+
try:
2655+
output = [0] * len(self.all_var_name)
2656+
theta = np.power(10, self.surrogate.theta)
2657+
imp = 100 * theta / np.max(theta)
2658+
ind = find_indices(A=self.var_name, B=self.all_var_name)
2659+
2660+
j = 0
2661+
for i in ind:
2662+
output[i] = imp[j]
2663+
j = j + 1
2664+
return output
2665+
except Exception as e:
2666+
print(f"Error calculating importance values: {str(e)}")
2667+
return []
25882668
else:
2589-
print("Importance requires more than one theta values (n_theta>1).")
2590-
# return a list of zeros of length len(all_var_name)
2591-
return [0] * len(self.all_var_name)
2669+
print("Importance requires more than one theta values (n_theta>1) and valid variable names.")
2670+
return []
25922671

25932672
def print_importance(self, threshold=0.1, print_screen=True) -> list:
25942673
"""Print importance of each variable and return the results as a list.

src/spotpython/utils/init.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ def fun_control_init(
386386
'path': None,
387387
'prep_model': None,
388388
'prep_model_name': None,
389-
'save_model': False,
390389
'scenario': "lightning",
391390
'seed': 1234,
392391
'show_batch_interval': 1000000,
@@ -485,7 +484,6 @@ def fun_control_init(
485484
"progress_file": progress_file,
486485
"save_experiment": save_experiment,
487486
"save_result": save_result,
488-
"save_model": False,
489487
"scaler": scaler,
490488
"scaler_name": scaler_name,
491489
"scenario": scenario,

test/test_spot_get_importance.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import pytest
2+
import numpy as np
3+
from spotpython.fun.objectivefunctions import Analytical
4+
from spotpython.spot import spot
5+
from spotpython.utils.init import (
6+
fun_control_init,
7+
surrogate_control_init,
8+
design_control_init
9+
)
10+
11+
@pytest.fixture
12+
def setup_spot_with_anisotropic():
13+
"""Fixture to create a Spot instance with anisotropic theta"""
14+
fun = Analytical().fun_sphere
15+
fun_control = fun_control_init(
16+
lower=np.array([-1, -1, -1]),
17+
upper=np.array([1, 1, 1]),
18+
fun_evals=20,
19+
var_name=["x1", "x2", "x3"],
20+
)
21+
design_control = design_control_init(init_size=10)
22+
surrogate_control = surrogate_control_init(
23+
n_theta="anisotropic",
24+
method="interpolation"
25+
)
26+
S = spot.Spot(
27+
fun=fun,
28+
fun_control=fun_control,
29+
design_control=design_control,
30+
surrogate_control=surrogate_control
31+
)
32+
S.run()
33+
return S
34+
35+
@pytest.fixture
36+
def setup_spot_with_single_theta():
37+
"""Fixture to create a Spot instance with single theta"""
38+
fun = Analytical().fun_sphere
39+
fun_control = fun_control_init(
40+
lower=np.array([-1, -1, -1]),
41+
upper=np.array([1, 1, 1]),
42+
fun_evals=20,
43+
var_name=["x1", "x2", "x3"],
44+
)
45+
design_control = design_control_init(init_size=10)
46+
surrogate_control = surrogate_control_init(
47+
n_theta=1,
48+
method="interpolation"
49+
)
50+
S = spot.Spot(
51+
fun=fun,
52+
fun_control=fun_control,
53+
design_control=design_control,
54+
surrogate_control=surrogate_control
55+
)
56+
S.run()
57+
return S
58+
59+
def test_importance_with_anisotropic_theta(setup_spot_with_anisotropic):
60+
"""Test importance calculation with anisotropic theta"""
61+
S = setup_spot_with_anisotropic
62+
importance = S.get_importance()
63+
64+
# Check if importance is returned as a list
65+
assert isinstance(importance, list)
66+
67+
# Check if importance has correct length
68+
assert len(importance) == len(S.all_var_name)
69+
70+
# Check if importance values are between 0 and 100
71+
assert all(0 <= imp <= 100 for imp in importance)
72+
73+
# Check if at least one importance value is 100
74+
assert max(importance) == 100
75+
76+
def test_importance_with_single_theta(setup_spot_with_single_theta):
77+
"""Test importance calculation with single theta"""
78+
S = setup_spot_with_single_theta
79+
importance = S.get_importance()
80+
81+
# Check if importance is returned as a list
82+
assert isinstance(importance, list)
83+
84+
# Check if importance has correct length
85+
assert len(importance) == 0
86+
87+
88+
def test_importance_without_surrogate():
89+
"""Test get_importance when no surrogate is available"""
90+
fun_control = fun_control_init(
91+
lower=np.array([-1, -1]),
92+
upper=np.array([1, 1])
93+
)
94+
S = spot.Spot(fun=lambda x: x, fun_control=fun_control, surrogate=None)
95+
importance = S.get_importance()
96+
assert importance == []
97+
98+
def test_importance_without_theta_attribute():
99+
"""Test get_importance when surrogate has no theta attribute"""
100+
fun_control = fun_control_init(
101+
lower=np.array([-1, -1]),
102+
upper=np.array([1, 1])
103+
)
104+
S = spot.Spot(fun=lambda x: x, fun_control=fun_control)
105+
class DummySurrogate:
106+
def __init__(self):
107+
self.n_theta = 2
108+
S.surrogate = DummySurrogate()
109+
importance = S.get_importance()
110+
assert importance == []
111+
112+
def test_importance_without_all_var_name():
113+
"""Test get_importance when all_var_name is not available"""

0 commit comments

Comments
 (0)