Skip to content

Commit 990064f

Browse files
committed
Switched fitresult to namedtuple type
1 parent 997f747 commit 990064f

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

src/mlj_interface.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,11 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
100100
verbose=verbose)
101101

102102
cluster_labels = MMI.categorical(1:m.k)
103-
fitresult = (result.centers, cluster_labels, result.converged)
103+
fitresult = (centers = result.centers, labels = cluster_labels, converged = result.converged)
104104
cache = nothing
105105

106106
report = (cluster_centers=result.centers, iterations=result.iterations,
107-
converged=result.converged, totalcost=result.totalcost,
108-
assignments=result.assignments, labels=cluster_labels)
107+
totalcost=result.totalcost, assignments=result.assignments, labels=cluster_labels)
109108

110109

111110
"""
@@ -120,7 +119,7 @@ end
120119

121120
function MMI.fitted_params(model::KMeans, fitresult)
122121
# Centroids
123-
return (cluster_centers = fitresult[1], )
122+
return (cluster_centers = fitresult.centers, )
124123
end
125124

126125

@@ -129,7 +128,7 @@ end
129128
####
130129

131130
function MMI.transform(m::KMeans, fitresult, Xnew)
132-
# make predictions/assignments using the learned centroids
131+
# transform new data using the fitted centroids.
133132

134133
if !m.copy
135134
# permutes dimensions of input table without copying and pass to model
@@ -140,13 +139,12 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
140139
end
141140

142141
# Warn users if fitresult is from a `non-converged` fit
143-
if !(fitresult[end])
142+
if !(fitresult.converged)
144143
@warn "Failed to converge. Using last assignments to make transformations."
145144
end
146145

147146
# use centroid matrix to assign clusters for new data
148-
centroids = fitresult[1]
149-
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
147+
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, fitresult.centers; dims=2)
150148
#preds = argmin.(eachrow(distances))
151149
return MMI.table(distances, prototype=Xnew)
152150
end

test/test07_mlj_interface.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ end
5050
model = KMeans(algo = :Lloyd, k=2)
5151
results, cache, report = fit(model, 0, X)
5252

53-
@test cache == nothing
54-
@test report.converged == true
55-
@test report.totalcost == 16
53+
@test cache == nothing
54+
@test results.converged == true
55+
@test report.totalcost == 16
5656

5757
params = fitted_params(model, results)
5858
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
@@ -76,9 +76,9 @@ end
7676
model = KMeans(algo = :Hamerly, k=2)
7777
results, cache, report = fit(model, 0, X)
7878

79-
@test cache == nothing
80-
@test report.converged == true
81-
@test report.totalcost == 16
79+
@test cache == nothing
80+
@test results.converged == true
81+
@test report.totalcost == 16
8282

8383
params = fitted_params(model, results)
8484
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
@@ -102,9 +102,9 @@ end
102102
model = KMeans(algo = :Elkan, k=2)
103103
results, cache, report = fit(model, 0, X)
104104

105-
@test cache == nothing
106-
@test report.converged == true
107-
@test report.totalcost == 16
105+
@test cache == nothing
106+
@test results.converged == true
107+
@test report.totalcost == 16
108108

109109
params = fitted_params(model, results)
110110
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]

0 commit comments

Comments
 (0)