Skip to content

Commit e9d8832

Browse files
author
Andrey Oskin
committed
added parallelize macro and docstrings
1 parent c4d1a2e commit e9d8832

File tree

3 files changed

+136
-98
lines changed

3 files changed

+136
-98
lines changed

src/hamerly.jl

Lines changed: 62 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,24 @@ function kmeans!(alg::Hamerly, containers, design_matrix, k;
1919
nrow, ncol = size(design_matrix)
2020
centroids = init == nothing ? smart_init(design_matrix, k, n_threads, init=k_init).centroids : deepcopy(init)
2121

22-
initialize!(alg, containers, centroids, design_matrix, n_threads)
22+
@parallelize n_threads ncol chunk_initialize!(alg, containers, centroids, design_matrix)
2323

2424
converged = false
2525
niters = 1
2626
J_previous = 0.0
27+
p = containers.p
2728

2829
# Update centroids & labels with closest members until convergence
29-
3030
while niters <= max_iters
3131
update_containers!(containers, alg, centroids, n_threads)
32-
update_centroids!(centroids, containers, alg, design_matrix, n_threads)
32+
@parallelize n_threads ncol chunk_update_centroids!(centroids, containers, alg, design_matrix)
33+
collect_containers(alg, containers, n_threads)
34+
3335
J = sum(containers.ub)
3436
move_centers!(centroids, containers, alg)
35-
update_bounds!(containers, n_threads)
37+
38+
r1, r2, pr1, pr2 = double_argmax(p)
39+
@parallelize n_threads ncol chunk_update_bounds!(containers, r1, r2, pr1, pr2)
3640

3741
if verbose
3842
# Show progress and terminate if J stopped decreasing.
@@ -49,7 +53,8 @@ function kmeans!(alg::Hamerly, containers, design_matrix, k;
4953
niters += 1
5054
end
5155

52-
totalcost = sum_of_squares(design_matrix, containers.labels, centroids)
56+
@parallelize n_threads ncol sum_of_squares(containers, design_matrix, containers.labels, centroids)
57+
totalcost = sum(containers.sum_of_squares)
5358

5459
# Terminate algorithm with the assumption that K-means has converged
5560
if verbose & converged
@@ -101,6 +106,9 @@ function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
101106
# distance from the center to the closest other center
102107
s = Vector{Float64}(undef, k)
103108

109+
# per-chunk partial sums of squares; summed afterwards to obtain the total cost
110+
sum_of_squares = Vector{Float64}(undef, n_threads)
111+
104112
return (
105113
centroids_new = centroids_new,
106114
centroids_cnt = centroids_cnt,
@@ -109,31 +117,15 @@ function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
109117
lb = lb,
110118
p = p,
111119
s = s,
120+
sum_of_squares = sum_of_squares
112121
)
113122
end
114123

115-
function initialize!(alg::Hamerly, containers, centroids, design_matrix, n_threads)
116-
ncol = size(design_matrix, 2)
117-
118-
if n_threads == 1
119-
r = axes(design_matrix, 2)
120-
chunk_initialize!(alg, containers, centroids, design_matrix, r, 1)
121-
else
122-
ranges = splitter(ncol, n_threads)
123-
124-
waiting_list = Vector{Task}(undef, n_threads - 1)
125-
126-
for i in 1:n_threads - 1
127-
waiting_list[i] = @spawn chunk_initialize!(alg, containers, centroids,
128-
design_matrix, ranges[i], i + 1)
129-
end
130-
131-
chunk_initialize!(alg, containers, centroids, design_matrix, ranges[end], 1)
132-
133-
wait.(waiting_list)
134-
end
135-
end
124+
"""
125+
chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
136126
127+
Initial calculation of all bounds and point labels.
128+
"""
137129
function chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
138130
centroids_cnt = containers.centroids_cnt[idx]
139131
centroids_new = containers.centroids_new[idx]
@@ -147,6 +139,11 @@ function chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r
147139
end
148140
end
149141

142+
"""
143+
update_containers!(containers, ::Hamerly, centroids, n_threads)
144+
145+
Calculates, for each center, the minimum distance to any other center.
146+
"""
150147
function update_containers!(containers, ::Hamerly, centroids, n_threads)
151148
s = containers.s
152149
s .= Inf
@@ -160,39 +157,14 @@ function update_containers!(containers, ::Hamerly, centroids, n_threads)
160157
end
161158
end
162159

163-
function update_centroids!(centroids, containers, alg::Hamerly, design_matrix, n_threads)
164-
165-
if n_threads == 1
166-
r = axes(design_matrix, 2)
167-
chunk_update_centroids!(centroids, containers, alg, design_matrix, r, 1)
168-
else
169-
ncol = size(design_matrix, 2)
170-
ranges = splitter(ncol, n_threads)
171-
172-
waiting_list = Vector{Task}(undef, n_threads - 1)
173-
174-
for i in 1:length(ranges) - 1
175-
waiting_list[i] = @spawn chunk_update_centroids!(centroids, containers,
176-
alg, design_matrix, ranges[i], i)
177-
end
178-
179-
chunk_update_centroids!(centroids, containers, alg, design_matrix, ranges[end], n_threads)
180-
181-
wait.(waiting_list)
182-
183-
end
184-
185-
collect_containers(alg, containers, n_threads)
186-
end
160+
"""
161+
chunk_update_centroids!(centroids, containers, alg::Hamerly, design_matrix, r, idx)
187162
188-
function chunk_update_centroids!(
189-
centroids,
190-
containers,
191-
alg::Hamerly,
192-
design_matrix,
193-
r,
194-
idx,
195-
)
163+
Detailed description of this function can be found in the original paper. It iterates through
164+
all points and tries to skip some calculation using known upper and lower bounds of distances
165+
from point to centers. If it fails to skip, then it falls back to the generic `point_all_centers!` function.
166+
"""
167+
function chunk_update_centroids!(centroids, containers, alg::Hamerly, design_matrix, r, idx)
196168

197169
# unpack containers for easier manipulations
198170
centroids_new = containers.centroids_new[idx]
@@ -227,6 +199,11 @@ function chunk_update_centroids!(
227199
end
228200
end
229201

202+
"""
203+
point_all_centers!(containers, centroids, design_matrix, i)
204+
205+
Calculates new labels and upper and lower bounds for all points.
206+
"""
230207
function point_all_centers!(containers, centroids, design_matrix, i)
231208
ub = containers.ub
232209
lb = containers.lb
@@ -253,6 +230,12 @@ function point_all_centers!(containers, centroids, design_matrix, i)
253230
return label
254231
end
255232

233+
"""
234+
move_centers!(centroids, containers, ::Hamerly)
235+
236+
Calculates the new positions of the centers and the distances they have moved. Results are stored
237+
in `centroids` and `p` respectively.
238+
"""
256239
function move_centers!(centroids, containers, ::Hamerly)
257240
centroids_new = containers.centroids_new[end]
258241
p = containers.p
@@ -267,35 +250,28 @@ function move_centers!(centroids, containers, ::Hamerly)
267250
end
268251
end
269252

270-
function update_bounds!(containers, n_threads)
271-
p = containers.p
253+
"""
254+
chunk_update_bounds!(containers, r1, r2, pr1, pr2, r, idx)
272255
273-
r1, r2 = double_argmax(p)
274-
pr1 = p[r1]
275-
pr2 = p[r2]
256+
Updates the upper and lower bounds of the point-to-center distances to account for the movement of the centers.
257+
Since the bounds are squared distances, `sqrt` is used to make the corresponding estimate, unlike
258+
the original paper, where the usual (non-squared) metric is used.
276259
277-
if n_threads == 1
278-
r = axes(containers.ub, 1)
279-
chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
280-
else
281-
ncol = length(containers.ub)
282-
ranges = splitter(ncol, n_threads)
260+
Using the notation from the original paper, `u` is the upper bound and `a` is `labels`, so
283261
284-
waiting_list = Vector{Task}(undef, n_threads - 1)
262+
`u[i] -> u[i] + p[a[i]]`
285263
286-
for i in 1:n_threads - 1
287-
waiting_list[i] = @spawn chunk_update_bounds!(containers, ranges[i], r1, r2, pr1, pr2)
288-
end
264+
then squared distance is
289265
290-
chunk_update_bounds!(containers, ranges[end], r1, r2, pr1, pr2)
266+
`u[i]^2 -> (u[i] + p[a[i]])^2 = u[i]^2 + 2 p[a[i]] u[i] + p[a[i]]^2`
291267
292-
for i in 1:n_threads - 1
293-
wait(waiting_list[i])
294-
end
295-
end
296-
end
268+
Taking into account that in our notation `p^2 -> p` and `u^2 -> ub`, we obtain
269+
270+
`ub[i] -> ub[i] + 2 sqrt(p[a[i]] ub[i]) + p[a[i]]`
297271
298-
function chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
272+
The same applies to the lower bounds.
273+
"""
274+
function chunk_update_bounds!(containers, r1, r2, pr1, pr2, r, idx)
299275
p = containers.p
300276
ub = containers.ub
301277
lb = containers.lb
@@ -312,6 +288,11 @@ function chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
312288
end
313289
end
314290

291+
"""
292+
double_argmax(p)
293+
294+
Finds the indices of the largest and second-largest elements of `p` and returns them together with the corresponding values.
295+
"""
315296
function double_argmax(p)
316297
r1, r2 = 1, 1
317298
d1 = p[1]
@@ -328,19 +309,5 @@ function double_argmax(p)
328309
end
329310
end
330311

331-
r1, r2
332-
end
333-
334-
"""
335-
distance(X1, X2, i1, i2)
336-
337-
Allocation less calculation of square eucledean distance between vectors X1[:, i1] and X2[:, i2]
338-
"""
339-
function distance(X1, X2, i1, i2)
340-
d = 0.0
341-
@inbounds for i in axes(X1, 1)
342-
d += (X1[i, i1] - X2[i, i2])^2
343-
end
344-
345-
return d
312+
r1, r2, d1, d2
346313
end

src/kmeans.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,66 @@ struct KmeansResult{C<:AbstractMatrix{<:AbstractFloat},D<:Real,WC<:Real} <: Clus
4141
converged::Bool # whether the procedure converged
4242
end
4343

44+
"""
45+
@parallelize(n_threads, ncol, f)
46+
47+
Parallelize a function call by splitting the work into `n_threads` chunks. The function must satisfy the following conditions:
48+
1. It should not return any values.
49+
2. It should accept two additional parameters at the end of the argument list. The first
50+
accepted parameter is `range`, which defines the chunk used in calculations. The second
51+
parameter is `idx`, which defines the id of the container where results can be stored.
52+
53+
The `ncol` argument defines the range `1:ncol`, which is split into `n_threads` chunks.
54+
"""
55+
macro parallelize(n_threads, ncol, f)
56+
for i in 1:length(f.args)
57+
f.args[i] = :($(esc(f.args[i])))
58+
end
59+
single_thread_chunk = copy(f)
60+
push!(single_thread_chunk.args, :(1:$(esc(ncol))))
61+
push!(single_thread_chunk.args, 1)
62+
63+
multi_thread_chunk = copy(f)
64+
push!(multi_thread_chunk.args, :(ranges[i]))
65+
push!(multi_thread_chunk.args, :(i))
66+
67+
last_multi_thread_chunk = copy(f)
68+
push!(last_multi_thread_chunk.args, :(ranges[end]))
69+
push!(last_multi_thread_chunk.args, :($(esc(n_threads))))
70+
71+
return quote
72+
if $(esc(n_threads)) == 1
73+
$single_thread_chunk
74+
else
75+
local ranges = splitter($(esc(ncol)), $(esc(n_threads)))
76+
local waiting_list = $(esc(Vector)){$(esc(Task))}(undef, $(esc(n_threads)) - 1)
77+
for i in 1:$(esc(n_threads)) - 1
78+
waiting_list[i] = @spawn $multi_thread_chunk
79+
end
80+
81+
$last_multi_thread_chunk
82+
83+
for i in 1:$(esc(n_threads)) - 1
84+
wait(waiting_list[i])
85+
end
86+
end
87+
end
88+
end
89+
90+
"""
91+
distance(X1, X2, i1, i2)
92+
93+
Allocation-free calculation of the squared Euclidean distance between vectors X1[:, i1] and X2[:, i2]
94+
"""
95+
function distance(X1, X2, i1, i2)
96+
d = 0.0
97+
@inbounds for i in axes(X1, 1)
98+
d += (X1[i, i1] - X2[i, i2])^2
99+
end
100+
101+
return d
102+
end
103+
44104
"""
45105
sum_of_squares(x, labels, centre, k)
46106
@@ -61,6 +121,17 @@ function sum_of_squares(x, labels, centre)
61121
return s
62122
end
63123

124+
function sum_of_squares(containers, x, labels, centre, r, idx)
125+
s = 0.0
126+
127+
@inbounds for j in r
128+
for i in axes(x, 1)
129+
s += (x[i, j] - centre[i, labels[j]])^2
130+
end
131+
end
132+
133+
containers.sum_of_squares[idx] = s
134+
end
64135

65136
"""
66137
Kmeans([alg::AbstractKMeansAlg,] design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)

test/test05_hamerly.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TestHamerly
22

33
using ParallelKMeans
4-
using ParallelKMeans: initialize!, double_argmax
4+
using ParallelKMeans: chunk_initialize!, double_argmax
55
using Test
66
using Random
77

@@ -11,13 +11,13 @@ using Random
1111
nrow, ncol = size(X)
1212
containers = ParallelKMeans.create_containers(Hamerly(), 3, nrow, ncol, 1)
1313

14-
ParallelKMeans.initialize!(Hamerly(), containers, centroids, X, 1)
14+
ParallelKMeans.chunk_initialize!(Hamerly(), containers, centroids, X, 1:ncol, 1)
1515
@test containers.lb == [18.0, 20.0, 5.0, 5.0]
1616
@test containers.ub == [0.0, 2.0, 0.0, 0.0]
1717
end
1818

1919
@testset "double argmax" begin
20-
@test double_argmax([0.5, 0, 0]) == (1, 2)
20+
@test double_argmax([0.5, 0, 0]) == (1, 2, 0.5, 0.0)
2121
end
2222

2323
@testset "singlethread linear separation" begin

0 commit comments

Comments
 (0)