Skip to content

Commit 68d5363

Browse files
authored
Merge pull request #61 from PyDataBlog/experimental
Updated MLJ Interface with predict function
2 parents 148167e + 18bacb1 commit 68d5363

File tree

2 files changed

+71
-47
lines changed

2 files changed

+71
-47
lines changed

src/mlj_interface.jl

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,19 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
9494

9595
# fit model and get results
9696
verbose = verbosity > 0 # Display fitting operations if verbosity > 0
97-
fitresult = ParallelKMeans.kmeans(algo, DMatrix, m.k;
97+
result = ParallelKMeans.kmeans(algo, DMatrix, m.k;
9898
n_threads = m.threads, k_init=m.k_init,
9999
max_iters=m.max_iters, tol=m.tol, init=m.init,
100100
verbose=verbose)
101101

102+
cluster_labels = MMI.categorical(1:m.k)
103+
fitresult = (centers = result.centers, labels = cluster_labels, converged = result.converged)
102104
cache = nothing
103-
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
104-
converged=fitresult.converged, totalcost=fitresult.totalcost,
105-
labels=fitresult.assignments)
105+
106+
report = (cluster_centers=result.centers, iterations=result.iterations,
107+
totalcost=result.totalcost, assignments=result.assignments, labels=cluster_labels)
108+
109+
106110
"""
107111
# TODO: warn users about non convergence
108112
if verbose & (!fitresult.converged)
@@ -114,16 +118,8 @@ end
114118

115119

116120
function MMI.fitted_params(model::KMeans, fitresult)
117-
# extract what's relevant from `fitresult`
118-
results, _, _ = fitresult # unpack fitresult
119-
centers = results.centers
120-
converged = results.converged
121-
iters = results.iterations
122-
totalcost = results.totalcost
123-
124-
# then return as a NamedTuple
125-
return (cluster_centers = centers, totalcost = totalcost,
126-
iterations = iters, converged = converged)
121+
# Centroids
122+
return (cluster_centers = fitresult.centers, )
127123
end
128124

129125

@@ -132,7 +128,7 @@ end
132128
####
133129

134130
function MMI.transform(m::KMeans, fitresult, Xnew)
135-
# make predictions/assignments using the learned centroids
131+
# transform new data using the fitted centroids.
136132

137133
if !m.copy
138134
# permutes dimensions of input table without copying and pass to model
@@ -143,21 +139,36 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
143139
end
144140

145141
# Warn users if fitresult is from a `non-converged` fit
146-
if !fitresult[end].converged
142+
if !fitresult.converged
147143
@warn "Failed to converge. Using last assignments to make transformations."
148144
end
149145

150-
# results from fitted model
151-
results = fitresult[1]
152-
153146
# use centroid matrix to assign clusters for new data
154-
centroids = results.centers
155-
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
156-
preds = argmin.(eachrow(distances))
157-
return MMI.table(reshape(preds, :, 1), prototype=Xnew)
147+
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, fitresult.centers; dims=2)
148+
#preds = argmin.(eachrow(distances))
149+
return MMI.table(distances, prototype=Xnew)
158150
end
159151

160152

153+
function MMI.predict(m::KMeans, fitresult, Xnew)
154+
locations, cluster_labels, _ = fitresult
155+
156+
Xarray = MMI.matrix(Xnew)
157+
(n, p), k = size(Xarray), m.k
158+
159+
pred = zeros(Int, n)
160+
@inbounds for i 1:n
161+
minv = Inf
162+
for j 1:k
163+
curv = Distances.evaluate(Distances.Euclidean(), view(Xarray, i, :), view(locations, :, j))
164+
P = curv < minv
165+
pred[i] = j * P + pred[i] * !P # if P is true --> j
166+
minv = curv * P + minv * !P # if P is true --> curvalue
167+
end
168+
end
169+
return cluster_labels[pred]
170+
end
171+
161172
####
162173
#### METADATA
163174
####
@@ -176,6 +187,7 @@ MMI.metadata_pkg.(KMeans,
176187
MMI.metadata_model(KMeans,
177188
input = MMI.Table(MMI.Continuous),
178189
output = MMI.Table(MMI.Continuous),
190+
target = AbstractArray{<:MMI.Multiclass},
179191
weights = false,
180192
descr = ParallelKMeans_Desc,
181193
path = "ParallelKMeans.KMeans")

test/test07_mlj_interface.jl

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,23 @@ end
4848
X_test = table([10 1])
4949

5050
model = KMeans(algo = :Lloyd, k=2)
51-
results = fit(model, 0, X)
51+
results, cache, report = fit(model, 0, X)
5252

53-
@test results[2] == nothing
54-
@test results[end].converged == true
55-
@test results[end].totalcost == 16
53+
@test cache == nothing
54+
@test results.converged == true
55+
@test report.totalcost == 16
5656

5757
params = fitted_params(model, results)
58-
@test params.converged == true
59-
@test params.totalcost == 16
58+
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
6059

6160
# Use trained model to cluster new data X_test
6261
preds = transform(model, results, X_test)
63-
@test preds[:x1][1] == 2
62+
@test preds[:x1][1] == 82.0
63+
@test preds[:x2][1] == 1.0
64+
65+
# Make predictions on new data X_test with fitted params
66+
yhat = predict(model, results, X_test)
67+
@test yhat[1] == 2
6468
end
6569

6670

@@ -69,20 +73,24 @@ end
6973
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
7074
X_test = table([10 1])
7175

72-
model = KMeans(algo=:Hamerly, k=2)
73-
results = fit(model, 0, X)
76+
model = KMeans(algo = :Hamerly, k=2)
77+
results, cache, report = fit(model, 0, X)
7478

75-
@test results[2] == nothing
76-
@test results[end].converged == true
77-
@test results[end].totalcost == 16
79+
@test cache == nothing
80+
@test results.converged == true
81+
@test report.totalcost == 16
7882

7983
params = fitted_params(model, results)
80-
@test params.converged == true
81-
@test params.totalcost == 16
84+
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
8285

8386
# Use trained model to cluster new data X_test
8487
preds = transform(model, results, X_test)
85-
@test preds[:x1][1] == 2
88+
@test preds[:x1][1] == 82.0
89+
@test preds[:x2][1] == 1.0
90+
91+
# Make predictions on new data X_test with fitted params
92+
yhat = predict(model, results, X_test)
93+
@test yhat[1] == 2
8694
end
8795

8896

@@ -91,20 +99,24 @@ end
9199
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
92100
X_test = table([10 1])
93101

94-
model = KMeans(algo=:Elkan, k=2)
95-
results = fit(model, 0, X)
102+
model = KMeans(algo = :Elkan, k=2)
103+
results, cache, report = fit(model, 0, X)
96104

97-
@test results[2] == nothing
98-
@test results[end].converged == true
99-
@test results[end].totalcost == 16
105+
@test cache == nothing
106+
@test results.converged == true
107+
@test report.totalcost == 16
100108

101109
params = fitted_params(model, results)
102-
@test params.converged == true
103-
@test params.totalcost == 16
110+
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
104111

105112
# Use trained model to cluster new data X_test
106113
preds = transform(model, results, X_test)
107-
@test preds[:x1][1] == 2
114+
@test preds[:x1][1] == 82.0
115+
@test preds[:x2][1] == 1.0
116+
117+
# Make predictions on new data X_test with fitted params
118+
yhat = predict(model, results, X_test)
119+
@test yhat[1] == 2
108120
end
109121

110122

@@ -114,7 +126,7 @@ end
114126
X_test = table([10 1])
115127

116128
model = KMeans(k=2, max_iters=1)
117-
results = fit(model, 0, X)
129+
results, cache, report = fit(model, 0, X)
118130

119131
@test_logs (:warn, "Failed to converge. Using last assignments to make transformations.") transform(model, results, X_test)
120132
end

0 commit comments

Comments
 (0)