Skip to content

Issue #64: added n_init to kmeans #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 93 additions & 71 deletions src/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const _kmeans_default_init = :kmpp
const _kmeans_default_maxiter = 100
const _kmeans_default_tol = 1.0e-6
const _kmeans_default_display = :none
const _kmeans_default_n_init = 10

function kmeans!{T<:AbstractFloat}(X::Matrix{T}, centers::Matrix{T};
weights=nothing,
Expand All @@ -43,18 +44,37 @@ function kmeans(X::Matrix, k::Int;
weights=nothing,
init=_kmeans_default_init,
maxiter::Integer=_kmeans_default_maxiter,
n_init::Integer=_kmeans_default_n_init,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that n_init comes from Python's sklearn (#64), but it doesn't sound like a best choice for me.
Maybe something like n_tries to reflect that the parameter defines how many times the algorithm, rather than some initialization procedure, is run?

Copy link
Contributor

@wildart wildart Sep 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or ntries? And wouldn't be an overkill to run 10 times? I recommend default value 1, because usually a quick partitioning is required and not necessarily best one. And, if one needs to find a best clustering, this parameter can be set to larger value explicitly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10 is what sklearn does at it sounds reasonable to me.
It isn't unusual to run 1000s of times, (that was done as the baseline for the affinity propagation paper)
If some need a quick partition they can ask for it.

The default shouldn't be so sensitive to random factors.

I think 10 strikes the right balance.
Though I could see argument for 3 or 30

tol::Real=_kmeans_default_tol,
display::Symbol=_kmeans_default_display)


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One last remaining extraneous newline.

m, n = size(X)
(2 <= k < n) || error("k must have 2 <= k < n.")
iseeds = initseeds(init, X, k)
centers = copyseeds(X, iseeds)
kmeans!(X, centers;
weights=weights,
maxiter=maxiter,
tol=tol,
display=display)
n_init > 0 || throw(ArgumentError("n_init must be greater than 0"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything is still being indented with tabs. You can do a find-and-replace to change everything at once, if you wish; I do it in Vim like

:%s/\t/    /g


lowestcost::Float64 = Inf
local bestresult::KmeansResult

for i = 1:n_init

iseeds = initseeds(init, X, k)
centers = copyseeds(X, iseeds)
result = kmeans!(X, centers;
weights=weights,
maxiter=maxiter,
tol=tol,
display=display)

if result.totalcost < lowestcost
lowestcost = result.totalcost
bestresult = result
end

end

return bestresult

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the line breaks on lines 60, 73, and 75 for style consistency with the rest of the code.

end

#### Core implementation
Expand All @@ -72,86 +92,88 @@ function _kmeans!{T<:AbstractFloat}(
tol::Real, # in: tolerance of change at convergence
displevel::Int) # in: the level of display

# initialize

k = size(centers, 2)
to_update = Array(Bool, k) # indicators of whether a center needs to be updated
unused = Int[]
num_affected::Int = k # number of centers, to which the distances need to be recomputed

dmat = pairwise(SqEuclidean(), centers, x)
dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
update_assignments!(dmat, true, assignments, costs, counts, to_update, unused)
objv = w == nothing ? sum(costs) : dot(w, costs)

# main loop
t = 0
converged = false
if displevel >= 2
@printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected"
println("-------------------------------------------------------------")
@printf("%7d %18.6e\n", t, objv)
end

while !converged && t < maxiter
t = t + 1
# initialize

# update (affected) centers
k = size(centers, 2)
to_update = Array(Bool, k) # indicators of whether a center needs to be updated
unused = Int[]
num_affected::Int = k # number of centers, to which the distances need to be recomputed

update_centers!(x, w, assignments, to_update, centers, cweights)
dmat = pairwise(SqEuclidean(), centers, x)
dmat = convert(Array{T}, dmat) #Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
update_assignments!(dmat, true, assignments, costs, counts, to_update, unused)
objv = w == nothing ? sum(costs) : dot(w, costs)

if !isempty(unused)
repick_unused_centers(x, costs, centers, unused)
end
# main loop
t = 0
converged = false
if displevel >= 2
@printf "%7s %18s %18s | %8s \n" "Iters" "objv" "objv-change" "affected"
println("-------------------------------------------------------------")
@printf("%7d %18.6e\n", t, objv)
end

# update pairwise distance matrix
while !converged && t < maxiter
t = t + 1

if !isempty(unused)
to_update[unused] = true
end
# update (affected) centers

if t == 1 || num_affected > 0.75 * k
pairwise!(dmat, SqEuclidean(), centers, x)
else
# if only a small subset is affected, only compute for that subset
affected_inds = find(to_update)
dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x)
dmat[affected_inds, :] = dmat_p
end
update_centers!(x, w, assignments, to_update, centers, cweights)

# update assignments
if !isempty(unused)
repick_unused_centers(x, costs, centers, unused)
end

update_assignments!(dmat, false, assignments, costs, counts, to_update, unused)
num_affected = sum(to_update) + length(unused)
# update pairwise distance matrix

# compute change of objective and determine convergence
if !isempty(unused)
to_update[unused] = true
end

prev_objv = objv
objv = w == nothing ? sum(costs) : dot(w, costs)
objv_change = objv - prev_objv
if t == 1 || num_affected > 0.75 * k
pairwise!(dmat, SqEuclidean(), centers, x)
else
# if only a small subset is affected, only compute for that subset
affected_inds = find(to_update)
dmat_p = pairwise(SqEuclidean(), centers[:, affected_inds], x)
dmat[affected_inds, :] = dmat_p
end

if objv_change > tol
warn("The objective value changes towards an opposite direction")
end
# update assignments

if abs(objv_change) < tol
converged = true
end
update_assignments!(dmat, false, assignments, costs, counts, to_update, unused)
num_affected = sum(to_update) + length(unused)

# display iteration information (if asked)
# compute change of objective and determine convergence

if displevel >= 2
@printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected)
end
end
prev_objv = objv
objv = w == nothing ? sum(costs) : dot(w, costs)
objv_change = objv - prev_objv

if displevel >= 1
if converged
println("K-means converged with $t iterations (objv = $objv)")
else
println("K-means terminated without convergence after $t iterations (objv = $objv)")
end
end
if objv_change > tol
warn("The objective value changes towards an opposite direction")
end

if abs(objv_change) < tol
converged = true
end

# display iteration information (if asked)

if displevel >= 2
@printf("%7d %18.6e %18.6e | %8d\n", t, objv, objv_change, num_affected)
end
end

if displevel >= 1
if converged
println("K-means converged with $t iterations (objv = $objv)")
else
println("K-means terminated without convergence after $t iterations (objv = $objv)")
end
end

return KmeansResult(centers, assignments, costs, counts, cweights,
@compat(Float64(objv)), t, converged)
Expand Down
6 changes: 3 additions & 3 deletions test/kmeans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ k = 10
x = rand(m, n)

# non-weighted
r = kmeans(x, k; maxiter=50)
r = kmeans(x, k; maxiter=50, n_init=2)
@test isa(r, KmeansResult{Float64})
@test size(r.centers) == (m, k)
@test length(r.assignments) == n
Expand All @@ -24,7 +24,7 @@ r = kmeans(x, k; maxiter=50)
@test_approx_eq sum(r.costs) r.totalcost

# non-weighted (float32)
r = kmeans(@compat(map(Float32, x)), k; maxiter=50)
r = kmeans(@compat(map(Float32, x)), k; maxiter=50, n_init=2)
@test isa(r, KmeansResult{Float32})
@test size(r.centers) == (m, k)
@test length(r.assignments) == n
Expand All @@ -37,7 +37,7 @@ r = kmeans(@compat(map(Float32, x)), k; maxiter=50)

# weighted
w = rand(n)
r = kmeans(x, k; maxiter=50, weights=w)
r = kmeans(x, k; maxiter=50, weights=w, n_init=2)
@test isa(r, KmeansResult{Float64})
@test size(r.centers) == (m, k)
@test length(r.assignments) == n
Expand Down