@@ -100,12 +100,11 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
100
100
verbose= verbose)
101
101
102
102
cluster_labels = MMI. categorical (1 : m. k)
103
- fitresult = (result. centers, cluster_labels, result. converged)
103
+ fitresult = (centers = result. centers, labels = cluster_labels, converged = result. converged)
104
104
cache = nothing
105
105
106
106
report = (cluster_centers= result. centers, iterations= result. iterations,
107
- converged= result. converged, totalcost= result. totalcost,
108
- assignments= result. assignments, labels= cluster_labels)
107
+ totalcost= result. totalcost, assignments= result. assignments, labels= cluster_labels)
109
108
110
109
111
110
"""
120
119
121
120
function MMI. fitted_params (model:: KMeans , fitresult)
122
121
# Centroids
123
- return (cluster_centers = fitresult[ 1 ] , )
122
+ return (cluster_centers = fitresult. centers , )
124
123
end
125
124
126
125
129
128
# ###
130
129
131
130
function MMI. transform (m:: KMeans , fitresult, Xnew)
132
- # make predictions/assignments using the learned centroids
131
+ # transform new data using the fitted centroids.
133
132
134
133
if ! m. copy
135
134
# permutes dimensions of input table without copying and pass to model
@@ -140,13 +139,12 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
140
139
end
141
140
142
141
# Warn users if fitresult is from a `non-converged` fit
143
- if ! (fitresult[ end ] )
142
+ if ! (fitresult. converged )
144
143
@warn " Failed to converge. Using last assignments to make transformations."
145
144
end
146
145
147
146
# use centroid matrix to assign clusters for new data
148
- centroids = fitresult[1 ]
149
- distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, centroids; dims= 2 )
147
+ distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, fitresult. centers; dims= 2 )
150
148
# preds = argmin.(eachrow(distances))
151
149
return MMI. table (distances, prototype= Xnew)
152
150
end
0 commit comments