|
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
19 | | - "MAX_TIME = 5\n", |
20 | | - "INIT_SIZE = 10\n", |
| 19 | + "MAX_TIME = 10\n", |
| 20 | + "INIT_SIZE = 50\n", |
21 | 21 | "CLASSIFICATION = True\n", |
22 | 22 | "REGRESSION = False\n", |
23 | 23 | "MOONS = True\n", |
|
32 | 32 | { |
33 | 33 | "data": { |
34 | 34 | "text/plain": [ |
35 | | - "'10-sklearn_p040025_5min_10init_2023-05-08_23-24-38'" |
| 35 | + "'10-sklearn_p040025_10min_50init_2023-05-09_00-04-46'" |
36 | 36 | ] |
37 | 37 | }, |
38 | 38 | "execution_count": 2, |
|
58 | 58 | "metadata": {}, |
59 | 59 | "source": [ |
60 | 60 | "# Chapter 10: Sequential Parameter Optimization\n", |
61 | | - "## Hyperparameter Tuning: sklearn decision tree" |
| 61 | + "## Hyperparameter Tuning: sklearn" |
62 | 62 | ] |
63 | 63 | }, |
64 | 64 | { |
|
130 | 130 | "from sklearn.ensemble import HistGradientBoostingRegressor\n", |
131 | 131 | "from sklearn.model_selection import cross_validate\n", |
132 | 132 | "from sklearn.datasets import fetch_openml\n", |
133 | | - "from sklearn.metrics import mean_absolute_error, accuracy_score, roc_curve, roc_auc_score\n", |
| 133 | + "from sklearn.metrics import mean_absolute_error, accuracy_score, roc_curve, roc_auc_score, log_loss, mean_squared_error\n", |
134 | 134 | "from sklearn.tree import DecisionTreeRegressor\n", |
135 | 135 | "from sklearn.datasets import make_regression\n", |
136 | 136 | "from sklearn.preprocessing import OneHotEncoder\n", |
|
144 | 144 | "from sklearn.linear_model import LogisticRegression\n", |
145 | 145 | "from sklearn.neighbors import KNeighborsClassifier\n", |
146 | 146 | "from sklearn.ensemble import GradientBoostingClassifier\n", |
| 147 | + "from sklearn.ensemble import GradientBoostingRegressor\n", |
| 148 | + "from sklearn.linear_model import ElasticNet\n", |
147 | 149 | "\n", |
148 | 150 | "warnings.filterwarnings(\"ignore\")\n", |
149 | 151 | "\n", |
|
354 | 356 | "outputs": [], |
355 | 357 | "source": [ |
356 | 358 | "# core_model = RidgeCV\n", |
357 | | - "# core_model = RandomForestClassifier\n", |
| 359 | + "# core_model = GradientBoostingRegressor\n", |
| 360 | + "# core_model = ElasticNet\n", |
| 361 | + "core_model = RandomForestClassifier\n", |
358 | 362 | "# core_model = SVC\n", |
359 | 363 | "# core_model = LogisticRegression\n", |
360 | 364 | "# core_model = KNeighborsClassifier\n", |
361 | | - "core_model = GradientBoostingClassifier\n", |
| 365 | + "# core_model = GradientBoostingClassifier\n", |
362 | 366 | "fun_control = add_core_model_to_fun_control(core_model=core_model,\n", |
363 | 367 | " fun_control=fun_control,\n", |
364 | 368 | " hyper_dict=SklearnHyperDict,\n", |
|
436 | 440 | "outputs": [], |
437 | 441 | "source": [ |
438 | 442 | "fun = HyperSklearn(seed=123, log_level=50).fun_sklearn\n", |
439 | | - "weights = -1.0\n", |
440 | | - "\n", |
| 443 | + "# metric_sklearn = roc_auc_score\n", |
| 444 | + "# weights = -1.0\n", |
| 445 | + "metric_sklearn = log_loss\n", |
| 446 | + "weights = 1.0\n", |
441 | 447 | "\n", |
442 | 448 | "fun_control.update({\n", |
443 | 449 | " \"horizon\": None,\n", |
|
447 | 453 | " \"log_level\": 50,\n", |
448 | 454 | " \"weight_coeff\": None,\n", |
449 | 455 | " \"metric\": None,\n", |
450 | | - " \"metric_sklearn\": roc_auc_score\n", |
| 456 | + " \"metric_sklearn\": metric_sklearn\n", |
451 | 457 | " })" |
452 | 458 | ] |
453 | 459 | }, |
|
499 | 505 | "name": "stdout", |
500 | 506 | "output_type": "stream", |
501 | 507 | "text": [ |
502 | | - "| name | type | default | lower | upper | transform |\n", |
503 | | - "|--------------------------|--------|--------------|---------|---------|------------------------|\n", |
504 | | - "| loss | factor | log_loss | 0 | 1 | None |\n", |
505 | | - "| learning_rate | float | 0.1 | 0.001 | 0.2 | None |\n", |
506 | | - "| n_estimators | int | 7 | 3 | 10 | transform_power_2_int |\n", |
507 | | - "| subsample | float | 0.0 | -10 | 0 | transform_power_2 |\n", |
508 | | - "| criterion | factor | friedman_mse | 0 | 1 | None |\n", |
509 | | - "| min_samples_split | int | 1 | 1 | 10 | transform_power_2_int |\n", |
510 | | - "| min_samples_leaf | int | 0 | 0 | 10 | transform_power_2_int |\n", |
511 | | - "| min_weight_fraction_leaf | float | 0.0 | 0 | 0.5 | None |\n", |
512 | | - "| max_depth | int | 2 | 1 | 20 | transform_power_2_int |\n", |
513 | | - "| min_impurity_decrease | float | 0.0 | 0 | 1e+06 | None |\n", |
514 | | - "| max_features | factor | none | 0 | 3 | transform_none_to_None |\n", |
515 | | - "| max_leaf_nodes | int | 10 | 1 | 12 | transform_power_2_int |\n", |
516 | | - "| tol | float | 0.0001 | 1e-05 | 0.001 | None |\n" |
| 508 | + "| name | type | default | lower | upper | transform |\n", |
| 509 | + "|--------------------------|--------|-----------|---------|---------|------------------------|\n", |
| 510 | + "| n_estimators | int | 7 | 5 | 9 | transform_power_2_int |\n", |
| 511 | + "| criterion | factor | gini | 0 | 2 | None |\n", |
| 512 | + "| max_depth | int | 10 | 1 | 20 | transform_power_2_int |\n", |
| 513 | + "| min_samples_split | int | 2 | 2 | 100 | None |\n", |
| 514 | + "| min_samples_leaf | int | 1 | 1 | 10 | None |\n", |
| 515 | + "| min_weight_fraction_leaf | float | 0.0 | 0 | 0.01 | None |\n", |
| 516 | + "| max_features | factor | sqrt | 0 | 1 | transform_none_to_None |\n", |
| 517 | + "| max_leaf_nodes | int | 10 | 7 | 12 | transform_power_2_int |\n", |
| 518 | + "| min_impurity_decrease | float | 0.0 | 0 | 0.01 | None |\n", |
| 519 | + "| bootstrap | factor | 1 | 0 | 1 | None |\n" |
517 | 520 | ] |
518 | 521 | } |
519 | 522 | ], |
|
540 | 543 | { |
541 | 544 | "data": { |
542 | 545 | "text/plain": [ |
543 | | - "array([[0.e+00, 1.e-01, 7.e+00, 0.e+00, 0.e+00, 1.e+00, 0.e+00, 0.e+00,\n", |
544 | | - " 2.e+00, 0.e+00, 3.e+00, 1.e+01, 1.e-04]])" |
| 546 | + "array([[ 7., 0., 10., 2., 1., 0., 0., 10., 0., 1.]])" |
545 | 547 | ] |
546 | 548 | }, |
547 | 549 | "execution_count": 18, |
|
565 | 567 | "name": "stdout", |
566 | 568 | "output_type": "stream", |
567 | 569 | "text": [ |
568 | | - "spotPython tuning: [##########] 97.10% \r" |
| 570 | + "spotPython tuning: [##--------] 23.07% \r" |
569 | 571 | ] |
570 | 572 | } |
571 | 573 | ], |
|
0 commit comments