@@ -94,15 +94,19 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
94
94
95
95
# fit model and get results
96
96
verbose = verbosity > 0 # Display fitting operations if verbosity > 0
97
- fitresult = ParallelKMeans. kmeans (algo, DMatrix, m. k;
97
+ result = ParallelKMeans. kmeans (algo, DMatrix, m. k;
98
98
n_threads = m. threads, k_init= m. k_init,
99
99
max_iters= m. max_iters, tol= m. tol, init= m. init,
100
100
verbose= verbose)
101
101
102
+ cluster_labels = MMI. categorical (1 : m. k)
103
+ fitresult = (centers = result. centers, labels = cluster_labels, converged = result. converged)
102
104
cache = nothing
103
- report = (cluster_centers= fitresult. centers, iterations= fitresult. iterations,
104
- converged= fitresult. converged, totalcost= fitresult. totalcost,
105
- labels= fitresult. assignments)
105
+
106
+ report = (cluster_centers= result. centers, iterations= result. iterations,
107
+ totalcost= result. totalcost, assignments= result. assignments, labels= cluster_labels)
108
+
109
+
106
110
"""
107
111
# TODO: warn users about non convergence
108
112
if verbose & (!fitresult.converged)
114
118
115
119
116
120
"""
    MMI.fitted_params(model::KMeans, fitresult)

Return the learned parameters of a fitted `KMeans` model.

# Arguments
- `model`: the `KMeans` model; only used for dispatch here.
- `fitresult`: the object produced by `MMI.fit`; its `centers` field is read.

# Returns
A one-field named tuple `(cluster_centers = fitresult.centers,)`.
"""
function MMI.fitted_params(model::KMeans, fitresult)
    # `fitresult` is accessed by field name, so this does not depend on field order.
    return (cluster_centers = fitresult.centers,)
end
128
124
129
125
132
128
# ###
133
129
134
130
function MMI. transform (m:: KMeans , fitresult, Xnew)
135
- # make predictions/assignments using the learned centroids
131
+ # transform new data using the fitted centroids.
136
132
137
133
if ! m. copy
138
134
# permutes dimensions of input table without copying and pass to model
@@ -143,21 +139,36 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
143
139
end
144
140
145
141
# Warn users if fitresult is from a `non-converged` fit
146
- if ! fitresult[ end ] . converged
142
+ if ! fitresult. converged
147
143
@warn " Failed to converge. Using last assignments to make transformations."
148
144
end
149
145
150
- # results from fitted model
151
- results = fitresult[1 ]
152
-
153
146
# use centroid matrix to assign clusters for new data
154
- centroids = results. centers
155
- distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, centroids; dims= 2 )
156
- preds = argmin .(eachrow (distances))
157
- return MMI. table (reshape (preds, :, 1 ), prototype= Xnew)
147
+ distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, fitresult. centers; dims= 2 )
148
+ # preds = argmin.(eachrow(distances))
149
+ return MMI. table (distances, prototype= Xnew)
158
150
end
159
151
160
152
153
"""
    MMI.predict(m::KMeans, fitresult, Xnew)

Assign every observation (row) of `Xnew` to its nearest fitted centroid and return the
matching cluster labels.

# Arguments
- `m`: the `KMeans` model; only used for dispatch here.
- `fitresult`: the object produced by `MMI.fit`; its `centers` field (one centroid per
  column) and `labels` field are read.
- `Xnew`: new data, converted with `MMI.matrix` so that each row is one observation.

# Returns
`fitresult.labels` indexed by the position of the nearest centroid of each row. Ties are
resolved in favour of the lowest centroid index.

# Throws
- `DimensionMismatch` if the number of columns of `Xnew` differs from the number of
  rows of the centroid matrix.
- `ArgumentError` if a row is not at a finite distance from any centroid (for example
  because it contains `NaN`), since no label can be chosen for it.
"""
function MMI.predict(m::KMeans, fitresult, Xnew)
    # Read by field name (as `transform` does) rather than by position in the tuple.
    locations = fitresult.centers
    cluster_labels = fitresult.labels

    Xarray = MMI.matrix(Xnew)
    n, p = size(Xarray)

    # Use the fitted centroid count: `m.k` may have been changed after fitting.
    k = size(locations, 2)

    p == size(locations, 1) || throw(DimensionMismatch(
        "Xnew has $p features but the fitted centroids have $(size(locations, 1))"))

    metric = Distances.Euclidean()
    pred = Vector{Int}(undef, n)
    for i in 1:n
        row = view(Xarray, i, :)
        best, minv = 0, Inf
        for j in 1:k
            curv = Distances.evaluate(metric, row, view(locations, :, j))
            # strict `<` keeps the first centroid among equally distant ones
            if curv < minv
                best, minv = j, curv
            end
        end
        best == 0 && throw(ArgumentError(
            "observation $i is not at a finite distance from any centroid; " *
            "check Xnew for NaN or infinite values"))
        pred[i] = best
    end
    return cluster_labels[pred]
end
171
+
161
172
# ###
162
173
# ### METADATA
163
174
# ###
@@ -176,6 +187,7 @@ MMI.metadata_pkg.(KMeans,
176
187
# Declare the MLJ model traits for `KMeans`: the scientific types accepted as input and
# produced by `transform` (`output`) and `predict` (`target`), whether per-observation
# weights are supported, a description, and the load path of the model type.
MMI.metadata_model(KMeans,
    input   = MMI.Table(MMI.Continuous),
    output  = MMI.Table(MMI.Continuous),
    target  = AbstractArray{<:MMI.Multiclass},
    weights = false,
    descr   = ParallelKMeans_Desc,
    path    = "ParallelKMeans.KMeans")
0 commit comments