|
| 1 | +import tkinter as tk |
| 2 | +from tkinter import ttk |
| 3 | + |
| 4 | +from spotRiver.tuner.run import run_spot_river_experiment, compare_tuned_default, contour_plot, parallel_plot, importance_plot, progress_plot |
| 5 | + |
| 6 | +result = None |
| 7 | +fun_control = None |
| 8 | + |
| 9 | + |
| 10 | +def run_experiment(): |
| 11 | + global result, fun_control |
| 12 | + MAX_TIME = float(max_time_entry.get()) |
| 13 | + INIT_SIZE = int(init_size_entry.get()) |
| 14 | + PREFIX = prefix_entry.get() |
| 15 | + horizon = int(horizon_entry.get()) |
| 16 | + n_total = n_total_entry.get() |
| 17 | + if n_total == "None" or n_total == "All": |
| 18 | + n_total = None |
| 19 | + else: |
| 20 | + n_total = int(n_total) |
| 21 | + perc_train = float(perc_train_entry.get()) |
| 22 | + oml_grace_period = oml_grace_period_entry.get() |
| 23 | + if oml_grace_period == "None" or oml_grace_period == "n_train": |
| 24 | + oml_grace_period = None |
| 25 | + else: |
| 26 | + oml_grace_period = int(oml_grace_period) |
| 27 | + data_set = data_set_combo.get() |
| 28 | + prep_model = prep_model_combo.get() |
| 29 | + core_model = core_model_combo.get() |
| 30 | + |
| 31 | + result, fun_control = run_spot_river_experiment( |
| 32 | + MAX_TIME=MAX_TIME, |
| 33 | + INIT_SIZE=INIT_SIZE, |
| 34 | + PREFIX=PREFIX, |
| 35 | + horizon=horizon, |
| 36 | + n_total=n_total, |
| 37 | + perc_train=perc_train, |
| 38 | + oml_grace_period=oml_grace_period, |
| 39 | + data_set=data_set, |
| 40 | + prepmodel=prep_model, |
| 41 | + coremodel=core_model, |
| 42 | + log_level=20, |
| 43 | + ) |
| 44 | + |
| 45 | + |
| 46 | +def call_compare_tuned_default(): |
| 47 | + if result is not None and fun_control is not None: |
| 48 | + compare_tuned_default(result, fun_control) |
| 49 | + |
| 50 | + |
| 51 | +def call_parallel_plot(): |
| 52 | + if result is not None: |
| 53 | + parallel_plot(result) |
| 54 | + |
| 55 | + |
| 56 | +def call_contour_plot(): |
| 57 | + if result is not None: |
| 58 | + contour_plot(result) |
| 59 | + |
| 60 | + |
| 61 | +def call_importance_plot(): |
| 62 | + if result is not None: |
| 63 | + importance_plot(result) |
| 64 | + |
| 65 | + |
| 66 | +def call_progress_plot(): |
| 67 | + if result is not None: |
| 68 | + progress_plot(result) |
| 69 | + |
| 70 | + |
| 71 | +# Create the main application window |
| 72 | +app = tk.Tk() |
| 73 | +app.title("Spot River Hyperparameter Tuning GUI") |
| 74 | + |
| 75 | +# Create a notebook (tabbed interface) |
| 76 | +notebook = ttk.Notebook(app) |
| 77 | +# notebook.pack(fill='both', expand=True) |
| 78 | + |
| 79 | +# Create and pack entry fields for the "Run" tab |
| 80 | +run_tab = ttk.Frame(notebook) |
| 81 | +notebook.add(run_tab, text="Binary classification") |
| 82 | + |
| 83 | +# colummns 0+1: Data |
| 84 | + |
| 85 | +data_label = tk.Label(run_tab, text="Data options:") |
| 86 | +data_label.grid(row=0, column=0, sticky="W") |
| 87 | + |
| 88 | +data_set_label = tk.Label(run_tab, text="Select data_set:") |
| 89 | +data_set_label.grid(row=1, column=0, sticky="W") |
| 90 | +data_set_values = ["Bananas", "CreditCard", "Elec2", "Higgs", "HTTP", "MaliciousURL", "Phishing", "SMSSpam", "SMTP", "TREC07", "USER"] |
| 91 | +data_set_combo = ttk.Combobox(run_tab, values=data_set_values) |
| 92 | +data_set_combo.set("Phishing") # Default selection |
| 93 | +data_set_combo.grid(row=1, column=1) |
| 94 | + |
| 95 | + |
| 96 | +n_total_label = tk.Label(run_tab, text="n_total:") |
| 97 | +n_total_label.grid(row=2, column=0, sticky="W") |
| 98 | +n_total_entry = tk.Entry(run_tab) |
| 99 | +n_total_entry.insert(0, "All") |
| 100 | +n_total_entry.grid(row=2, column=1, sticky="W") |
| 101 | + |
| 102 | +perc_train_label = tk.Label(run_tab, text="perc_train:") |
| 103 | +perc_train_label.grid(row=3, column=0, sticky="W") |
| 104 | +perc_train_entry = tk.Entry(run_tab) |
| 105 | +perc_train_entry.insert(0, "0.60") |
| 106 | +perc_train_entry.grid(row=3, column=1, sticky="W") |
| 107 | + |
| 108 | + |
| 109 | +# colummns 2+3: Model |
| 110 | +model_label = tk.Label(run_tab, text="Model options:") |
| 111 | +model_label.grid(row=0, column=2, sticky="W") |
| 112 | + |
| 113 | +prep_model_label = tk.Label(run_tab, text="Select preprocessing model") |
| 114 | +prep_model_label.grid(row=1, column=2, sticky="W") |
| 115 | +prep_model_values = ["MinMaxScaler", "StandardScaler", "None"] |
| 116 | +prep_model_combo = ttk.Combobox(run_tab, values=prep_model_values) |
| 117 | +prep_model_combo.set("StandardScaler") # Default selection |
| 118 | +prep_model_combo.grid(row=1, column=3) |
| 119 | + |
| 120 | + |
| 121 | +core_model_label = tk.Label(run_tab, text="Select core model") |
| 122 | +core_model_label.grid(row=2, column=2, sticky="W") |
| 123 | +core_model_values = ["AMFClassifier", "HoeffdingAdaptiveTreeClassifier", "LogisticRegression"] |
| 124 | +core_model_combo = ttk.Combobox(run_tab, values=core_model_values) |
| 125 | +core_model_combo.set("LogisticRegression") # Default selection |
| 126 | +core_model_combo.grid(row=2, column=3) |
| 127 | + |
| 128 | + |
| 129 | +# columns 4+5: Experiment |
| 130 | +experiment_label = tk.Label(run_tab, text="Experiment options:") |
| 131 | +experiment_label.grid(row=0, column=4, sticky="W") |
| 132 | + |
| 133 | +max_time_label = tk.Label(run_tab, text="MAX_TIME:") |
| 134 | +max_time_label.grid(row=1, column=4, sticky="W") |
| 135 | +max_time_entry = tk.Entry(run_tab) |
| 136 | +max_time_entry.insert(0, "1") |
| 137 | +max_time_entry.grid(row=1, column=5) |
| 138 | + |
| 139 | +init_size_label = tk.Label(run_tab, text="INIT_SIZE:") |
| 140 | +init_size_label.grid(row=2, column=4, sticky="W") |
| 141 | +init_size_entry = tk.Entry(run_tab) |
| 142 | +init_size_entry.insert(0, "3") |
| 143 | +init_size_entry.grid(row=2, column=5) |
| 144 | + |
| 145 | +prefix_label = tk.Label(run_tab, text="PREFIX:") |
| 146 | +prefix_label.grid(row=3, column=4, sticky="W") |
| 147 | +prefix_entry = tk.Entry(run_tab) |
| 148 | +prefix_entry.insert(0, "00") |
| 149 | +prefix_entry.grid(row=3, column=5) |
| 150 | + |
| 151 | +horizon_label = tk.Label(run_tab, text="horizon:") |
| 152 | +horizon_label.grid(row=4, column=4, sticky="W") |
| 153 | +horizon_entry = tk.Entry(run_tab) |
| 154 | +horizon_entry.insert(0, "1") |
| 155 | +horizon_entry.grid(row=4, column=5) |
| 156 | + |
| 157 | +oml_grace_period_label = tk.Label(run_tab, text="oml_grace_period:") |
| 158 | +oml_grace_period_label.grid(row=5, column=4, sticky="W") |
| 159 | +oml_grace_period_entry = tk.Entry(run_tab) |
| 160 | +oml_grace_period_entry.insert(0, "n_train") |
| 161 | +oml_grace_period_entry.grid(row=5, column=5) |
| 162 | + |
| 163 | +# column 6: Run button |
| 164 | +run_button = ttk.Button(run_tab, text="Run Experiment", command=run_experiment) |
| 165 | +run_button.grid(row=7, column=6, columnspan=2, sticky="E") |
| 166 | + |
| 167 | +# Create and pack the "Regression" tab with a button to run the analysis |
| 168 | +regression_tab = ttk.Frame(notebook) |
| 169 | +notebook.add(regression_tab, text="Regression") |
| 170 | + |
| 171 | +# colummns 0+1: Data |
| 172 | + |
| 173 | +regression_data_label = tk.Label(regression_tab, text="Data options:") |
| 174 | +regression_data_label.grid(row=0, column=0, sticky="W") |
| 175 | + |
| 176 | +# colummns 2+3: Model |
| 177 | +regression_model_label = tk.Label(regression_tab, text="Model options:") |
| 178 | +regression_model_label.grid(row=0, column=2, sticky="W") |
| 179 | + |
| 180 | +# columns 4+5: Experiment |
| 181 | +regression_experiment_label = tk.Label(regression_tab, text="Experiment options:") |
| 182 | +regression_experiment_label.grid(row=0, column=4, sticky="W") |
| 183 | + |
| 184 | + |
| 185 | +# Create and pack the "Analysis" tab with a button to run the analysis |
| 186 | +analysis_tab = ttk.Frame(notebook) |
| 187 | +notebook.add(analysis_tab, text="Analysis") |
| 188 | + |
| 189 | +notebook.pack() |
| 190 | + |
| 191 | + |
| 192 | +# Add the Logo image in both tabs |
| 193 | +logo_image = tk.PhotoImage(file="images/spotlogo.png") |
| 194 | +logo_label = tk.Label(run_tab, image=logo_image) |
| 195 | +logo_label.grid(row=0, column=6, rowspan=1, columnspan=1) |
| 196 | + |
| 197 | +analysis_label = tk.Label(analysis_tab, text="Analysis options:") |
| 198 | +analysis_label.grid(row=0, column=1, sticky="W") |
| 199 | + |
| 200 | +progress_plot_button = ttk.Button(analysis_tab, text="Progress plot", command=call_progress_plot) |
| 201 | +progress_plot_button.grid(row=1, column=1, columnspan=2, sticky="W") |
| 202 | + |
| 203 | +compare_tuned_default_button = ttk.Button(analysis_tab, text="Compare tuned vs. default", command=call_compare_tuned_default) |
| 204 | +compare_tuned_default_button.grid(row=2, column=1, columnspan=2, sticky="W") |
| 205 | + |
| 206 | +importance_plot_button = ttk.Button(analysis_tab, text="Importance plot", command=call_importance_plot) |
| 207 | +importance_plot_button.grid(row=3, column=1, columnspan=2, sticky="W") |
| 208 | + |
| 209 | +contour_plot_button = ttk.Button(analysis_tab, text="Contour plot", command=call_contour_plot) |
| 210 | +contour_plot_button.grid(row=4, column=1, columnspan=2, sticky="W") |
| 211 | + |
| 212 | +parallel_plot_button = ttk.Button(analysis_tab, text="Parallel plot (Browser)", command=call_parallel_plot) |
| 213 | +parallel_plot_button.grid(row=5, column=1, columnspan=2, sticky="W") |
| 214 | + |
| 215 | + |
| 216 | +analysis_logo_label = tk.Label(analysis_tab, image=logo_image) |
| 217 | +analysis_logo_label.grid(row=0, column=6, rowspan=1, columnspan=1) |
| 218 | + |
| 219 | +regression_logo_label = tk.Label(regression_tab, image=logo_image) |
| 220 | +regression_logo_label.grid(row=0, column=6, rowspan=1, columnspan=1) |
| 221 | + |
| 222 | +# Run the mainloop |
| 223 | + |
| 224 | +app.mainloop() |
0 commit comments