|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +This is the Kriging surrogate model. |
| 4 | +It is based on the DACE matlab toolbox. |
| 5 | +It can handle numerical and categorical variables. |
| 6 | +""" |
| 7 | +import numpy as np |
| 8 | +from scipy.spatial.distance import cdist |
| 9 | +from scipy.linalg import cholesky, cho_solve, solve_triangular |
| 10 | + |
| 11 | + |
| 12 | +class Kriging: |
| 13 | + """ |
| 14 | + Kriging class with optional Nyström approximation for scalability. |
| 15 | + This class implements the Kriging surrogate model, also known as |
| 16 | + Gaussian Process regression. It is adapted to handle both numerical |
| 17 | + (ordered) and categorical (factor) variables, a key feature of spotpython. |
| 18 | + The Nyström approximation is added as an optional feature to handle |
| 19 | + large datasets efficiently. |
| 20 | + """ |
| 21 | + |
| 22 | + def __init__(self, fun_control, n_theta=None, theta=None, p=2.0, corr="squared_exponential", isotropic=False, approximation="None", n_landmarks=100): |
| 23 | + """ |
| 24 | + Initialize the Kriging model. |
| 25 | +
|
| 26 | + Args: |
| 27 | + fun_control (dict): Control dictionary from spotpython, containing |
| 28 | + problem dimensions, variable types ('var_type'), etc. |
| 29 | + n_theta (int, optional): Number of correlation parameters (theta). |
| 30 | + Defaults to problem dimension for anisotropic model. |
| 31 | + theta (np.ndarray, optional): Initial correlation parameters. |
| 32 | + Defaults to 0.1 for all dimensions. |
| 33 | + p (float, optional): Power for the correlation function. Defaults to 2.0. |
| 34 | + corr (str, optional): Correlation function type. |
| 35 | + Defaults to "squared_exponential". |
| 36 | + isotropic (bool, optional): Whether to use an isotropic model (one theta |
| 37 | + for all dimensions). Defaults to False. |
| 38 | + approximation (str, optional): Type of approximation to use. |
| 39 | + "None" for standard Kriging, |
| 40 | + "nystroem" for Nyström approximation. |
| 41 | + Defaults to "None". |
| 42 | + n_landmarks (int, optional): Number of landmark points for Nyström. |
| 43 | + Only used if approximation="nystroem". |
| 44 | + Defaults to 100. |
| 45 | + """ |
| 46 | + self.fun_control = fun_control |
| 47 | + self.dim = self.fun_control["lower"].shape |
| 48 | + self.p = p |
| 49 | + self.corr = corr |
| 50 | + self.isotropic = isotropic |
| 51 | + self.approximation = approximation |
| 52 | + self.n_landmarks = n_landmarks |
| 53 | + |
| 54 | + # Setup masks for variable types |
| 55 | + self.factor_mask = self.fun_control["var_type"] == "factor" |
| 56 | + self.ordered_mask = ~self.factor_mask |
| 57 | + |
| 58 | + # Determine number of theta parameters |
| 59 | + if self.isotropic: |
| 60 | + self.n_theta = 1 |
| 61 | + elif n_theta is None: |
| 62 | + self.n_theta = self.dim |
| 63 | + else: |
| 64 | + self.n_theta = n_theta |
| 65 | + |
| 66 | + # Initialize theta |
| 67 | + if theta is None: |
| 68 | + self.theta = np.full(self.n_theta, 0.1) |
| 69 | + else: |
| 70 | + self.theta = theta |
| 71 | + |
| 72 | + # Model state attributes |
| 73 | + self.X_ = None |
| 74 | + self.y_ = None |
| 75 | + self.L_ = None # Cholesky factor for standard Kriging |
| 76 | + self.alpha_ = None # Solved term for standard Kriging |
| 77 | + |
| 78 | + # Nyström-specific attributes |
| 79 | + self.landmarks_ = None |
| 80 | + self.W_cho_ = None # Cholesky factor of W matrix |
| 81 | + self.nystrom_alpha_ = None # Solved term for Nyström prediction |
| 82 | + |
| 83 | + def fit(self, X, y): |
| 84 | + """ |
| 85 | + Fit the Kriging model to the training data. |
| 86 | +
|
| 87 | + Args: |
| 88 | + X (np.ndarray): Training data of shape (n_samples, n_features). |
| 89 | + y (np.ndarray): Target values of shape (n_samples,). |
| 90 | + """ |
| 91 | + self.X_ = X |
| 92 | + self.y_ = y |
| 93 | + n_samples = X.shape |
| 94 | + |
| 95 | + if self.approximation.lower() == "nystroem": |
| 96 | + if n_samples <= self.n_landmarks: |
| 97 | + # Fallback to standard Kriging if not enough samples |
| 98 | + self._fit_standard(X, y) |
| 99 | + else: |
| 100 | + self._fit_nystrom(X, y) |
| 101 | + else: |
| 102 | + self._fit_standard(X, y) |
| 103 | + |
| 104 | + def _fit_standard(self, X, y): |
| 105 | + """Standard Kriging fitting procedure.""" |
| 106 | + # Build the full covariance matrix Psi |
| 107 | + Psi = self.build_Psi(X, X) |
| 108 | + Psi[np.diag_indices_from(Psi)] += 1e-8 # Add jitter for stability |
| 109 | + |
| 110 | + try: |
| 111 | + # Compute Cholesky decomposition |
| 112 | + self.L_ = cholesky(Psi, lower=True) |
| 113 | + # Solve for alpha = L'\(L\y) |
| 114 | + self.alpha_ = cho_solve((self.L_, True), y) |
| 115 | + except np.linalg.LinAlgError: |
| 116 | + # Fallback to pseudo-inverse if Cholesky fails |
| 117 | + pi_Psi = np.linalg.pinv(Psi) |
| 118 | + self.alpha_ = np.dot(pi_Psi, y) |
| 119 | + self.L_ = None # Indicate that Cholesky failed |
| 120 | + |
| 121 | + def _fit_nystrom(self, X, y): |
| 122 | + """Nyström approximation fitting procedure.""" |
| 123 | + n_samples = X.shape |
| 124 | + |
| 125 | + # 1. Select landmark points using uniform random sampling without replacement |
| 126 | + landmark_indices = np.random.choice(n_samples, self.n_landmarks, replace=False) |
| 127 | + self.landmarks_ = X[landmark_indices, :] |
| 128 | + |
| 129 | + # 2. Construct core matrices using build_Psi |
| 130 | + # W = K_mm (landmark-landmark covariance) |
| 131 | + W = self.build_Psi(self.landmarks_, self.landmarks_) |
| 132 | + W += 1e-8 # Add jitter |
| 133 | + |
| 134 | + # C = K_nm (data-landmark cross-covariance) |
| 135 | + C = self.build_Psi(X, self.landmarks_) |
| 136 | + |
| 137 | + # 3. Compute Cholesky decomposition of W |
| 138 | + try: |
| 139 | + self.W_cho_ = cholesky(W, lower=True) |
| 140 | + except np.linalg.LinAlgError: |
| 141 | + self.W_cho_ = None |
| 142 | + # Fallback to standard Kriging as a safe option |
| 143 | + self._fit_standard(X, y) |
| 144 | + return |
| 145 | + |
| 146 | + # 4. Pre-compute terms for prediction |
| 147 | + # Solve for nystrom_alpha = W_inv * C.T * y |
| 148 | + Ct_y = C.T @ y |
| 149 | + self.nystrom_alpha_ = cho_solve((self.W_cho_, True), Ct_y) |
| 150 | + |
| 151 | + def predict(self, X_star): |
| 152 | + """ |
| 153 | + Make predictions with the fitted Kriging model. |
| 154 | +
|
| 155 | + Args: |
| 156 | + X_star (np.ndarray): Test data of shape (n_test_samples, n_features). |
| 157 | +
|
| 158 | + Returns: |
| 159 | + tuple: A tuple containing: |
| 160 | + - y_pred (np.ndarray): Predicted mean values. |
| 161 | + - y_mse (np.ndarray): Mean squared error (predictive variance). |
| 162 | + """ |
| 163 | + if self.approximation.lower() == "nystroem" and self.landmarks_ is not None: |
| 164 | + return self._predict_nystrom(X_star) |
| 165 | + else: |
| 166 | + return self._predict_standard(X_star) |
| 167 | + |
| 168 | + def _predict_standard(self, X_star): |
| 169 | + """Standard Kriging prediction procedure.""" |
| 170 | + # Build cross-covariance vector/matrix psi |
| 171 | + psi = self.build_Psi(X_star, self.X_) |
| 172 | + |
| 173 | + # Predictive mean |
| 174 | + y_pred = psi @ self.alpha_ |
| 175 | + |
| 176 | + # Predictive variance |
| 177 | + if self.L_ is not None: |
| 178 | + v = solve_triangular(self.L_, psi.T, lower=True) |
| 179 | + y_mse = 1.0 - np.sum(v**2, axis=0) |
| 180 | + y_mse[y_mse < 0] = 0 |
| 181 | + else: |
| 182 | + pi_Psi = np.linalg.pinv(self.build_Psi(self.X_, self.X_) + 1e-8 * np.eye(self.X_.shape)) |
| 183 | + y_mse = 1.0 - np.sum((psi @ pi_Psi) * psi, axis=1) |
| 184 | + y_mse[y_mse < 0] = 0 |
| 185 | + |
| 186 | + return y_pred, y_mse.reshape(-1, 1) |
| 187 | + |
| 188 | + def _predict_nystrom(self, X_star): |
| 189 | + """Nyström approximation prediction procedure.""" |
| 190 | + # 1. Compute cross-covariance between test points and landmarks |
| 191 | + psi_star_m = self.build_Psi(X_star, self.landmarks_) |
| 192 | + |
| 193 | + # 2. Predictive mean |
| 194 | + y_pred = psi_star_m @ self.nystrom_alpha_ |
| 195 | + |
| 196 | + # 3. Predictive variance |
| 197 | + if self.W_cho_ is not None: |
| 198 | + v = cho_solve((self.W_cho_, True), psi_star_m.T) |
| 199 | + quad_term = np.sum(psi_star_m * v.T, axis=1) |
| 200 | + y_mse = 1.0 - quad_term |
| 201 | + y_mse[y_mse < 0] = 0 |
| 202 | + else: |
| 203 | + y_mse = np.ones(X_star.shape) # Return max uncertainty |
| 204 | + |
| 205 | + return y_pred, y_mse.reshape(-1, 1) |
| 206 | + |
| 207 | + def build_Psi(self, X1, X2): |
| 208 | + """Builds the covariance matrix Psi between two sets of points.""" |
| 209 | + n1 = X1.shape |
| 210 | + Psi = np.zeros((n1, X2.shape)) |
| 211 | + for i in range(n1): |
| 212 | + Psi[i, :] = self.build_psi_vec(X1[i, :], X2) |
| 213 | + return Psi |
| 214 | + |
| 215 | + def build_psi_vec(self, x, X_): |
| 216 | + """ |
| 217 | + Builds a covariance vector between a point x and a set of points X_. |
| 218 | + This method correctly handles mixed (ordered/factor) variable types. |
| 219 | + """ |
| 220 | + # Handle theta for isotropic vs. anisotropic cases |
| 221 | + if self.isotropic: |
| 222 | + theta10 = np.full(self.dim, 10**self.theta) |
| 223 | + else: |
| 224 | + theta10 = 10**self.theta |
| 225 | + |
| 226 | + D = np.zeros(X_.shape) |
| 227 | + |
| 228 | + # Compute ordered distance contributions |
| 229 | + if self.ordered_mask.any(): |
| 230 | + X_ordered = X_[:, self.ordered_mask] |
| 231 | + x_ordered = x[self.ordered_mask] |
| 232 | + D += cdist(x_ordered.reshape(1, -1), X_ordered, metric="sqeuclidean", w=theta10[self.ordered_mask]).ravel() |
| 233 | + |
| 234 | + # Compute factor distance contributions |
| 235 | + if self.factor_mask.any(): |
| 236 | + X_factor = X_[:, self.factor_mask] |
| 237 | + x_factor = x[self.factor_mask] |
| 238 | + # Hamming distance for factors |
| 239 | + D += cdist(x_factor.reshape(1, -1), X_factor, metric="hamming", w=theta10[self.factor_mask]).ravel() * self.factor_mask.sum() |
| 240 | + |
| 241 | + # Apply correlation function |
| 242 | + if self.corr == "squared_exponential": |
| 243 | + psi = np.exp(-D) |
| 244 | + else: |
| 245 | + # Fallback for other potential correlation functions |
| 246 | + psi = np.exp(-(D**self.p)) |
| 247 | + |
| 248 | + return psi |
0 commit comments