Skip to content

Commit f5a2333

Browse files
committed
switch from Rfast to philentropy
1 parent da441c8 commit f5a2333

File tree

12 files changed

Lines changed: 134 additions & 29 deletions

12 files changed

Lines changed: 134 additions & 29 deletions

R/extract_fit_summary.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,27 @@ extract_fit_summary.hclust <- function(object, ...) {
167167
sse_within_total_total <- map2_dbl(
168168
by_clust$data,
169169
seq_len(n_clust),
170-
~sum(Rfast::dista(centroids[.y, ], .x))
170+
~sum(
171+
philentropy::dist_many_many(
172+
as.matrix(centroids[.y, ]),
173+
as.matrix(.x),
174+
method = "euclidean"
175+
)
176+
)
171177
)
172178

173179
list(
174180
cluster_names = unique(clusts),
175181
centroids = centroids,
176182
n_members = unname(as.integer(table(clusts))),
177183
sse_within_total_total = sse_within_total_total,
178-
sse_total = sum(Rfast::dista(t(overall_centroid), training_data)),
184+
sse_total = sum(
185+
philentropy::dist_many_many(
186+
t(overall_centroid),
187+
as.matrix(training_data),
188+
method = "euclidean"
189+
)
190+
),
179191
orig_labels = NULL,
180192
cluster_assignments = clusts
181193
)

R/metric-helpers.R

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,20 @@ prep_data_dist <- function(
6565
#' @param new_data A data frame
6666
#' @param centroids A data frame where each row is a centroid.
6767
#' @param dist_fun A function for computing matrix-to-matrix distances. Defaults
68-
#' to `Rfast::dista()`
69-
get_centroid_dists <- function(new_data, centroids, dist_fun = Rfast::dista) {
68+
#' to
69+
#' `function(x, y) philentropy::dist_many_many(x, y, method = "euclidean")`.
70+
get_centroid_dists <- function(
71+
new_data,
72+
centroids,
73+
dist_fun = function(x, y) {
74+
philentropy::dist_many_many(x, y, method = "euclidean")
75+
}
76+
) {
7077
if (ncol(new_data) != ncol(centroids)) {
7178
rlang::abort("Centroids must have same columns as data.")
7279
}
7380

7481
suppressMessages(
75-
dist_fun(centroids, new_data)
82+
dist_fun(as.matrix(centroids), as.matrix(new_data))
7683
)
7784
}

R/metric-sse.R

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
#'
2020
#' sse_within(kmeans_fit)
2121
#' @export
22-
sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
22+
sse_within <- function(
23+
object,
24+
new_data = NULL,
25+
dist_fun = function(x, y) {
26+
philentropy::dist_many_many(x, y, method = "euclidean")
27+
}
28+
) {
2329
if (inherits(object, "cluster_spec")) {
2430
rlang::abort(
2531
paste(
@@ -44,7 +50,10 @@ sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
4450
)
4551
} else {
4652
suppressMessages(
47-
dist_to_centroids <- dist_fun(summ$centroids, new_data)
53+
dist_to_centroids <- dist_fun(
54+
as.matrix(summ$centroids),
55+
as.matrix(new_data)
56+
)
4857
)
4958

5059
res <- dist_to_centroids %>%
@@ -123,7 +132,9 @@ sse_within_total.cluster_fit <- function(
123132
...
124133
) {
125134
if (is.null(dist_fun)) {
126-
dist_fun <- Rfast::dista
135+
dist_fun <- function(x, y) {
136+
philentropy::dist_many_many(x, y, method = "euclidean")
137+
}
127138
}
128139

129140
res <- sse_within_total_impl(object, new_data, dist_fun, ...)
@@ -144,7 +155,9 @@ sse_within_total.workflow <- sse_within_total.cluster_fit
144155
sse_within_total_vec <- function(
145156
object,
146157
new_data = NULL,
147-
dist_fun = Rfast::dista,
158+
dist_fun = function(x, y) {
159+
philentropy::dist_many_many(x, y, method = "euclidean")
160+
},
148161
...
149162
) {
150163
sse_within_total_impl(object, new_data, dist_fun, ...)
@@ -153,7 +166,9 @@ sse_within_total_vec <- function(
153166
sse_within_total_impl <- function(
154167
object,
155168
new_data = NULL,
156-
dist_fun = Rfast::dista,
169+
dist_fun = function(x, y) {
170+
philentropy::dist_many_many(x, y, method = "euclidean")
171+
},
157172
...
158173
) {
159174
sum(sse_within(object, new_data, dist_fun, ...)$wss, na.rm = TRUE)
@@ -210,7 +225,9 @@ sse_total.cluster_fit <- function(
210225
...
211226
) {
212227
if (is.null(dist_fun)) {
213-
dist_fun <- Rfast::dista
228+
dist_fun <- function(x, y) {
229+
philentropy::dist_many_many(x, y, method = "euclidean")
230+
}
214231
}
215232

216233
res <- sse_total_impl(object, new_data, dist_fun, ...)
@@ -231,7 +248,9 @@ sse_total.workflow <- sse_total.cluster_fit
231248
sse_total_vec <- function(
232249
object,
233250
new_data = NULL,
234-
dist_fun = Rfast::dista,
251+
dist_fun = function(x, y) {
252+
philentropy::dist_many_many(x, y, method = "euclidean")
253+
},
235254
...
236255
) {
237256
sse_total_impl(object, new_data, dist_fun, ...)
@@ -240,7 +259,9 @@ sse_total_vec <- function(
240259
sse_total_impl <- function(
241260
object,
242261
new_data = NULL,
243-
dist_fun = Rfast::dista,
262+
dist_fun = function(x, y) {
263+
philentropy::dist_many_many(x, y, method = "euclidean")
264+
},
244265
...
245266
) {
246267
# Preprocess data before computing distances if appropriate
@@ -256,7 +277,8 @@ sse_total_impl <- function(
256277
overall_mean <- colSums(summ$centroids * summ$n_members) /
257278
sum(summ$n_members)
258279
suppressMessages(
259-
tot <- dist_fun(t(as.matrix(overall_mean)), new_data)^2 %>% sum()
280+
tot <- dist_fun(t(as.matrix(overall_mean)), as.matrix(new_data))^2 %>%
281+
sum()
260282
)
261283
}
262284

@@ -314,7 +336,9 @@ sse_ratio.cluster_fit <- function(
314336
...
315337
) {
316338
if (is.null(dist_fun)) {
317-
dist_fun <- Rfast::dista
339+
dist_fun <- function(x, y) {
340+
philentropy::dist_many_many(x, y, method = "euclidean")
341+
}
318342
}
319343
res <- sse_ratio_impl(object, new_data, dist_fun, ...)
320344

@@ -334,7 +358,9 @@ sse_ratio.workflow <- sse_ratio.cluster_fit
334358
sse_ratio_vec <- function(
335359
object,
336360
new_data = NULL,
337-
dist_fun = Rfast::dista,
361+
dist_fun = function(x, y) {
362+
philentropy::dist_many_many(x, y, method = "euclidean")
363+
},
338364
...
339365
) {
340366
sse_ratio_impl(object, new_data, dist_fun, ...)
@@ -343,7 +369,9 @@ sse_ratio_vec <- function(
343369
sse_ratio_impl <- function(
344370
object,
345371
new_data = NULL,
346-
dist_fun = Rfast::dista,
372+
dist_fun = function(x, y) {
373+
philentropy::dist_many_many(x, y, method = "euclidean")
374+
},
347375
...
348376
) {
349377
sse_within_total_vec(object, new_data, dist_fun) /

R/predict_helpers.R

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ make_predictions <- function(x, prefix, n_clusters) {
9696
)
9797

9898
# need this to be obs on rows, dist to new data on cols
99-
dists_new <- Rfast::dista(xnew = training_data, x = new_data, trans = TRUE)
99+
dists_new <- philentropy::dist_many_many(
100+
training_data,
101+
new_data,
102+
method = "euclidean"
103+
)
100104

101105
cluster_dists <- dplyr::bind_cols(data.frame(dists_new), clusters) %>%
102106
dplyr::group_by(.cluster) %>%
@@ -109,7 +113,12 @@ make_predictions <- function(x, prefix, n_clusters) {
109113
## Centroid linkage_method, dist to center
110114

111115
cluster_centers <- extract_centroids(object) %>% dplyr::select(-.cluster)
112-
dists_means <- Rfast::dista(new_data, cluster_centers)
116+
117+
dists_means <- philentropy::dist_many_many(
118+
new_data,
119+
cluster_centers,
120+
method = "euclidean"
121+
)
113122

114123
pred_clusts_num <- apply(dists_means, 1, which.min)
115124
} else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) {

man/get_centroid_dists.Rd

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/prep_data_dist.Rd

Lines changed: 6 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/silhouette.Rd

Lines changed: 6 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sse_ratio.Rd

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sse_total.Rd

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sse_within.Rd

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)