@@ -17,12 +17,14 @@ const _kmeans_default_init = :kmpp
17
17
const _kmeans_default_maxiter = 100
18
18
const _kmeans_default_tol = 1.0e-6
19
19
const _kmeans_default_display = :none
20
+ const _kmeans_default_n_init = 10
20
21
21
22
function kmeans! {T<:AbstractFloat} (X:: Matrix{T} , centers:: Matrix{T} ;
22
23
weights= nothing ,
23
24
maxiter:: Integer = _kmeans_default_maxiter,
24
25
tol:: Real = _kmeans_default_tol,
25
26
display:: Symbol = _kmeans_default_display)
27
+
26
28
27
29
m, n = size (X)
28
30
m2, k = size (centers)
@@ -43,18 +45,37 @@ function kmeans(X::Matrix, k::Int;
43
45
weights= nothing ,
44
46
init= _kmeans_default_init,
45
47
maxiter:: Integer = _kmeans_default_maxiter,
48
+ n_init:: Integer = _kmeans_default_n_init,
46
49
tol:: Real = _kmeans_default_tol,
47
50
display:: Symbol = _kmeans_default_display)
51
+
48
52
49
53
m, n = size (X)
50
54
(2 <= k < n) || error (" k must have 2 <= k < n." )
51
- iseeds = initseeds (init, X, k)
52
- centers = copyseeds (X, iseeds)
53
- kmeans! (X, centers;
54
- weights= weights,
55
- maxiter= maxiter,
56
- tol= tol,
57
- display= display)
55
+ n_init > 0 || error (" n_init must be greater than 0" )
56
+
57
+ lowestcost:: Float64 = Inf
58
+ local bestresult:: KmeansResult
59
+
60
+ for i = 1 : n_init
61
+
62
+ iseeds = initseeds (init, X, k)
63
+ centers = copyseeds (X, iseeds)
64
+ result = kmeans! (X, centers;
65
+ weights= weights,
66
+ maxiter= maxiter,
67
+ tol= tol,
68
+ display= display)
69
+
70
+ if result. totalcost < lowestcost
71
+ lowestcost = result. totalcost
72
+ bestresult = result
73
+ end
74
+
75
+ end
76
+
77
+ return bestresult
78
+
58
79
end
59
80
60
81
# ### Core implementation
@@ -72,86 +93,88 @@ function _kmeans!{T<:AbstractFloat}(
72
93
tol:: Real , # in: tolerance of change at convergence
73
94
displevel:: Int ) # in: the level of display
74
95
75
- # initialize
76
-
77
- k = size (centers, 2 )
78
- to_update = Array (Bool, k) # indicators of whether a center needs to be updated
79
- unused = Int[]
80
- num_affected:: Int = k # number of centers, to which the distances need to be recomputed
81
-
82
- dmat = pairwise (SqEuclidean (), centers, x)
83
- dmat = convert (Array{T}, dmat) # Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
84
- update_assignments! (dmat, true , assignments, costs, counts, to_update, unused)
85
- objv = w == nothing ? sum (costs) : dot (w, costs)
86
-
87
- # main loop
88
- t = 0
89
- converged = false
90
- if displevel >= 2
91
- @printf " %7s %18s %18s | %8s \n " " Iters" " objv" " objv-change" " affected"
92
- println (" -------------------------------------------------------------" )
93
- @printf (" %7d %18.6e\n " , t, objv)
94
- end
95
96
96
- while ! converged && t < maxiter
97
- t = t + 1
97
+
98
+ # initialize
98
99
99
- # update (affected) centers
100
+ k = size (centers, 2 )
101
+ to_update = Array (Bool, k) # indicators of whether a center needs to be updated
102
+ unused = Int[]
103
+ num_affected:: Int = k # number of centers, to which the distances need to be recomputed
100
104
101
- update_centers! (x, w, assignments, to_update, centers, cweights)
105
+ dmat = pairwise (SqEuclidean (), centers, x)
106
+ dmat = convert (Array{T}, dmat) # Can be removed if one day Distance.result_type(SqEuclidean(), T, T) == T
107
+ update_assignments! (dmat, true , assignments, costs, counts, to_update, unused)
108
+ objv = w == nothing ? sum (costs) : dot (w, costs)
102
109
103
- if ! isempty (unused)
104
- repick_unused_centers (x, costs, centers, unused)
105
- end
110
+ # main loop
111
+ t = 0
112
+ converged = false
113
+ if displevel >= 2
114
+ @printf " %7s %18s %18s | %8s \n " " Iters" " objv" " objv-change" " affected"
115
+ println (" -------------------------------------------------------------" )
116
+ @printf (" %7d %18.6e\n " , t, objv)
117
+ end
106
118
107
- # update pairwise distance matrix
119
+ while ! converged && t < maxiter
120
+ t = t + 1
108
121
109
- if ! isempty (unused)
110
- to_update[unused] = true
111
- end
122
+ # update (affected) centers
112
123
113
- if t == 1 || num_affected > 0.75 * k
114
- pairwise! (dmat, SqEuclidean (), centers, x)
115
- else
116
- # if only a small subset is affected, only compute for that subset
117
- affected_inds = find (to_update)
118
- dmat_p = pairwise (SqEuclidean (), centers[:, affected_inds], x)
119
- dmat[affected_inds, :] = dmat_p
120
- end
124
+ update_centers! (x, w, assignments, to_update, centers, cweights)
121
125
122
- # update assignments
126
+ if ! isempty (unused)
127
+ repick_unused_centers (x, costs, centers, unused)
128
+ end
123
129
124
- update_assignments! (dmat, false , assignments, costs, counts, to_update, unused)
125
- num_affected = sum (to_update) + length (unused)
130
+ # update pairwise distance matrix
126
131
127
- # compute change of objective and determine convergence
132
+ if ! isempty (unused)
133
+ to_update[unused] = true
134
+ end
128
135
129
- prev_objv = objv
130
- objv = w == nothing ? sum (costs) : dot (w, costs)
131
- objv_change = objv - prev_objv
136
+ if t == 1 || num_affected > 0.75 * k
137
+ pairwise! (dmat, SqEuclidean (), centers, x)
138
+ else
139
+ # if only a small subset is affected, only compute for that subset
140
+ affected_inds = find (to_update)
141
+ dmat_p = pairwise (SqEuclidean (), centers[:, affected_inds], x)
142
+ dmat[affected_inds, :] = dmat_p
143
+ end
132
144
133
- if objv_change > tol
134
- warn (" The objective value changes towards an opposite direction" )
135
- end
145
+ # update assignments
136
146
137
- if abs (objv_change) < tol
138
- converged = true
139
- end
147
+ update_assignments! (dmat, false , assignments, costs, counts, to_update, unused)
148
+ num_affected = sum (to_update) + length (unused)
140
149
141
- # display iteration information (if asked)
150
+ # compute change of objective and determine convergence
142
151
143
- if displevel >= 2
144
- @printf (" %7d %18.6e %18.6e | %8d\n " , t, objv, objv_change, num_affected)
145
- end
146
- end
152
+ prev_objv = objv
153
+ objv = w == nothing ? sum (costs) : dot (w, costs)
154
+ objv_change = objv - prev_objv
147
155
148
- if displevel >= 1
149
- if converged
150
- println (" K-means converged with $t iterations (objv = $objv )" )
151
- else
152
- println (" K-means terminated without convergence after $t iterations (objv = $objv )" )
153
- end
154
- end
156
+ if objv_change > tol
157
+ warn (" The objective value changes towards an opposite direction" )
158
+ end
159
+
160
+ if abs (objv_change) < tol
161
+ converged = true
162
+ end
163
+
164
+ # display iteration information (if asked)
165
+
166
+ if displevel >= 2
167
+ @printf (" %7d %18.6e %18.6e | %8d\n " , t, objv, objv_change, num_affected)
168
+ end
169
+ end
170
+
171
+ if displevel >= 1
172
+ if converged
173
+ println (" K-means converged with $t iterations (objv = $objv )" )
174
+ else
175
+ println (" K-means terminated without convergence after $t iterations (objv = $objv )" )
176
+ end
177
+ end
155
178
156
179
return KmeansResult (centers, assignments, costs, counts, cweights,
157
180
@compat (Float64 (objv)), t, converged)
0 commit comments