Skip to content

Commit 775eefc

Browse files
lbollar authored and lollar committed
Issue JuliaStats#64: added n_init to kmeans
1 parent ea03689 commit 775eefc

File tree

2 files changed: +97 −74 lines changed

src/kmeans.jl

Lines changed: 94 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ const _kmeans_default_init = :kmpp
1717
const _kmeans_default_maxiter = 100
1818
const _kmeans_default_tol = 1.0e-6
1919
const _kmeans_default_display = :none
20+
const _kmeans_default_n_init = 10
2021

2122
function kmeans!{T<:AbstractFloat}(X::Matrix{T}, centers::Matrix{T};
2223
weights=nothing,
2324
maxiter::Integer=_kmeans_default_maxiter,
2425
tol::Real=_kmeans_default_tol,
2526
display::Symbol=_kmeans_default_display)
27+
2628

2729
m, n = size(X)
2830
m2, k = size(centers)
@@ -43,18 +45,37 @@ function kmeans(X::Matrix, k::Int;
4345
weights=nothing,
4446
init=_kmeans_default_init,
4547
maxiter::Integer=_kmeans_default_maxiter,
48+
n_init::Integer=_kmeans_default_n_init,
4649
tol::Real=_kmeans_default_tol,
4750
display::Symbol=_kmeans_default_display)
51+
4852

4953
m, n = size(X)
5054
(2 <= k < n) || error("k must have 2 <= k < n.")
51-
iseeds = initseeds(init, X, k)
52-
centers = copyseeds(X, iseeds)
53-
kmeans!(X, centers;
54-
weights=weights,
55-
maxiter=maxiter,
56-
tol=tol,
57-
display=display)
55+
n_init > 0 || error("n_init must be greater than 0")
56+
57+
lowestcost::Float64 = Inf
58+
local bestresult::KmeansResult
59+
60+
for i = 1:n_init
61+
62+
iseeds = initseeds(init, X, k)
63+
centers = copyseeds(X, iseeds)
64+
result = kmeans!(X, centers;
65+
weights=weights,
66+
maxiter=maxiter,
67+
tol=tol,
68+
display=display)
69+
70+
if result.totalcost < lowestcost
71+
lowestcost = result.totalcost
72+
bestresult = result
73+
end
74+
75+
end
76+
77+
return bestresult
78+
5879
end
5980

6081
#### Core implementation
@@ -72,86 +93,88 @@ function _kmeans!{T<:AbstractFloat}(
7293
tol::Real, # in: tolerance of change at convergence
7394
displevel::Int) # in: the level of display
7495

75-
# initialize
76-
77-
k = size(centers, 2)
78-
to_update = Array(Bool, k) # indicators of whether a center needs to be updated
79-
unused = Int[]
80-
num_affected::Int = k # number of centers, to which the distances need to be recomputed
81-
82-
dmat = pairwise(SqEuclidean(), centers, x)
83-
dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
84-
update_assignments!(dmat, true, assignments, costs, counts, to_update, unused)
85-
objv = w == nothing ? sum(costs) : dot(w, costs)
86-
87-
# main loop
88-
t = 0
89-
converged = false
90-
if displevel >= 2
91-
@printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected"
92-
println("-------------------------------------------------------------")
93-
@printf("%7d %18.6e\n", t, objv)
94-
end
9596

96-
while !converged && t < maxiter
97-
t = t + 1
97+
98+
# initialize
9899

99-
# update (affected) centers
100+
k = size(centers, 2)
101+
to_update = Array(Bool, k) # indicators of whether a center needs to be updated
102+
unused = Int[]
103+
num_affected::Int = k # number of centers, to which the distances need to be recomputed
100104

101-
update_centers!(x, w, assignments, to_update, centers, cweights)
105+
dmat = pairwise(SqEuclidean(), centers, x)
106+
dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
107+
update_assignments!(dmat, true, assignments, costs, counts, to_update, unused)
108+
objv = w == nothing ? sum(costs) : dot(w, costs)
102109

103-
if !isempty(unused)
104-
repick_unused_centers(x, costs, centers, unused)
105-
end
110+
# main loop
111+
t = 0
112+
converged = false
113+
if displevel >= 2
114+
@printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected"
115+
println("-------------------------------------------------------------")
116+
@printf("%7d %18.6e\n", t, objv)
117+
end
106118

107-
# update pairwise distance matrix
119+
while !converged && t < maxiter
120+
t = t + 1
108121

109-
if !isempty(unused)
110-
to_update[unused] = true
111-
end
122+
# update (affected) centers
112123

113-
if t == 1 || num_affected > 0.75 * k
114-
pairwise!(dmat, SqEuclidean(), centers, x)
115-
else
116-
# if only a small subset is affected, only compute for that subset
117-
affected_inds = find(to_update)
118-
dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x)
119-
dmat[affected_inds, :] = dmat_p
120-
end
124+
update_centers!(x, w, assignments, to_update, centers, cweights)
121125

122-
# update assignments
126+
if !isempty(unused)
127+
repick_unused_centers(x, costs, centers, unused)
128+
end
123129

124-
update_assignments!(dmat, false, assignments, costs, counts, to_update, unused)
125-
num_affected = sum(to_update) + length(unused)
130+
# update pairwise distance matrix
126131

127-
# compute change of objective and determine convergence
132+
if !isempty(unused)
133+
to_update[unused] = true
134+
end
128135

129-
prev_objv = objv
130-
objv = w == nothing ? sum(costs) : dot(w, costs)
131-
objv_change = objv - prev_objv
136+
if t == 1 || num_affected > 0.75 * k
137+
pairwise!(dmat, SqEuclidean(), centers, x)
138+
else
139+
# if only a small subset is affected, only compute for that subset
140+
affected_inds = find(to_update)
141+
dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x)
142+
dmat[affected_inds, :] = dmat_p
143+
end
132144

133-
if objv_change > tol
134-
warn("The objective value changes towards an opposite direction")
135-
end
145+
# update assignments
136146

137-
if abs(objv_change) < tol
138-
converged = true
139-
end
147+
update_assignments!(dmat, false, assignments, costs, counts, to_update, unused)
148+
num_affected = sum(to_update) + length(unused)
140149

141-
# display iteration information (if asked)
150+
# compute change of objective and determine convergence
142151

143-
if displevel >= 2
144-
@printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected)
145-
end
146-
end
152+
prev_objv = objv
153+
objv = w == nothing ? sum(costs) : dot(w, costs)
154+
objv_change = objv - prev_objv
147155

148-
if displevel >= 1
149-
if converged
150-
println("K-means converged with $t iterations (objv = $objv)")
151-
else
152-
println("K-means terminated without convergence after $t iterations (objv = $objv)")
153-
end
154-
end
156+
if objv_change > tol
157+
warn("The objective value changes towards an opposite direction")
158+
end
159+
160+
if abs(objv_change) < tol
161+
converged = true
162+
end
163+
164+
# display iteration information (if asked)
165+
166+
if displevel >= 2
167+
@printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected)
168+
end
169+
end
170+
171+
if displevel >= 1
172+
if converged
173+
println("K-means converged with $t iterations (objv = $objv)")
174+
else
175+
println("K-means terminated without convergence after $t iterations (objv = $objv)")
176+
end
177+
end
155178

156179
return KmeansResult(centers, assignments, costs, counts, cweights,
157180
@compat(Float64(objv)), t, converged)

test/kmeans.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ k = 10
1212
x = rand(m, n)
1313

1414
# non-weighted
15-
r = kmeans(x, k; maxiter=50)
15+
r = kmeans(x, k; maxiter=50, n_init=2)
1616
@test isa(r, KmeansResult{Float64})
1717
@test size(r.centers) == (m, k)
1818
@test length(r.assignments) == n
@@ -24,7 +24,7 @@ r = kmeans(x, k; maxiter=50)
2424
@test_approx_eq sum(r.costs) r.totalcost
2525

2626
# non-weighted (float32)
27-
r = kmeans(@compat(map(Float32, x)), k; maxiter=50)
27+
r = kmeans(@compat(map(Float32, x)), k; maxiter=50, n_init=2)
2828
@test isa(r, KmeansResult{Float32})
2929
@test size(r.centers) == (m, k)
3030
@test length(r.assignments) == n
@@ -37,7 +37,7 @@ r = kmeans(@compat(map(Float32, x)), k; maxiter=50)
3737

3838
# weighted
3939
w = rand(n)
40-
r = kmeans(x, k; maxiter=50, weights=w)
40+
r = kmeans(x, k; maxiter=50, weights=w, n_init=2)
4141
@test isa(r, KmeansResult{Float64})
4242
@test size(r.centers) == (m, k)
4343
@test length(r.assignments) == n

0 commit comments

Comments
 (0)