Skip to content

Commit 2f9313a

Browse files
Merge pull request #199 from tidymodels/use-philentropy
2 parents c9edff9 + 8542615 commit 2f9313a

18 files changed

+158
-41
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ Imports:
2626
hardhat (>= 1.0.0),
2727
modelenv (>= 0.2.0.9000),
2828
parsnip (>= 1.0.2),
29+
philentropy (>= 0.9.0),
2930
prettyunits (>= 1.1.0),
30-
Rfast (>= 2.0.6),
3131
rlang (>= 1.0.6),
3232
rsample (>= 1.0.0),
3333
stats,

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# tidyclust (development version)
22

3+
* The philentropy package is now used to calculate distances rather than Rfast. (#199)
4+
35
# tidyclust 0.2.3
46

57
* Update to fix revdep issue for clustMixType. (#190)

R/extract_fit_summary.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,27 @@ extract_fit_summary.hclust <- function(object, ...) {
167167
sse_within_total_total <- map2_dbl(
168168
by_clust$data,
169169
seq_len(n_clust),
170-
~sum(Rfast::dista(centroids[.y, ], .x))
170+
~sum(
171+
philentropy::dist_many_many(
172+
as.matrix(centroids[.y, ]),
173+
as.matrix(.x),
174+
method = "euclidean"
175+
)
176+
)
171177
)
172178

173179
list(
174180
cluster_names = unique(clusts),
175181
centroids = centroids,
176182
n_members = unname(as.integer(table(clusts))),
177183
sse_within_total_total = sse_within_total_total,
178-
sse_total = sum(Rfast::dista(t(overall_centroid), training_data)),
184+
sse_total = sum(
185+
philentropy::dist_many_many(
186+
t(overall_centroid),
187+
as.matrix(training_data),
188+
method = "euclidean"
189+
)
190+
),
179191
orig_labels = NULL,
180192
cluster_assignments = clusts
181193
)

R/hier_clust.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,11 @@ translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) {
193193
num_clusters = NULL,
194194
cut_height = NULL,
195195
linkage_method = NULL,
196-
dist_fun = Rfast::Dist
196+
dist_fun = philentropy::distance
197197
) {
198-
dmat <- dist_fun(x)
198+
suppressMessages(
199+
dmat <- dist_fun(x)
200+
)
199201
res <- stats::hclust(stats::as.dist(dmat), method = linkage_method)
200202
attr(res, "num_clusters") <- num_clusters
201203
attr(res, "cut_height") <- cut_height

R/metric-helpers.R

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ prep_data_dist <- function(
1212
object,
1313
new_data = NULL,
1414
dists = NULL,
15-
dist_fun = Rfast::Dist
15+
dist_fun = philentropy::distance
1616
) {
1717
# Silhouettes require a distance matrix
1818
if (is.null(new_data) && is.null(dists)) {
@@ -46,7 +46,9 @@ prep_data_dist <- function(
4646

4747
# Calculate distances including optionally supplied params
4848
if (is.null(dists)) {
49-
dists <- dist_fun(new_data)
49+
suppressMessages(
50+
dists <- dist_fun(new_data)
51+
)
5052
}
5153

5254
return(
@@ -63,11 +65,20 @@ prep_data_dist <- function(
6365
#' @param new_data A data frame
6466
#' @param centroids A data frame where each row is a centroid.
6567
#' @param dist_fun A function for computing matrix-to-matrix distances. Defaults
66-
#' to `Rfast::dista()`
67-
get_centroid_dists <- function(new_data, centroids, dist_fun = Rfast::dista) {
68+
#' to
69+
#' `function(x, y) philentropy::dist_many_many(x, y, method = "euclidean")`.
70+
get_centroid_dists <- function(
71+
new_data,
72+
centroids,
73+
dist_fun = function(x, y) {
74+
philentropy::dist_many_many(x, y, method = "euclidean")
75+
}
76+
) {
6877
if (ncol(new_data) != ncol(centroids)) {
6978
rlang::abort("Centroids must have same columns as data.")
7079
}
7180

72-
dist_fun(centroids, new_data)
81+
suppressMessages(
82+
dist_fun(as.matrix(centroids), as.matrix(new_data))
83+
)
7384
}

R/metric-silhouette.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ silhouette <- function(
2727
object,
2828
new_data = NULL,
2929
dists = NULL,
30-
dist_fun = Rfast::Dist
30+
dist_fun = philentropy::distance
3131
) {
3232
if (inherits(object, "cluster_spec")) {
3333
rlang::abort(
@@ -126,7 +126,7 @@ silhouette_avg.cluster_fit <- function(
126126
...
127127
) {
128128
if (is.null(dist_fun)) {
129-
dist_fun <- Rfast::Dist
129+
dist_fun <- philentropy::distance
130130
}
131131

132132
res <- silhouette_avg_impl(object, new_data, dists, dist_fun, ...)
@@ -148,7 +148,7 @@ silhouette_avg_vec <- function(
148148
object,
149149
new_data = NULL,
150150
dists = NULL,
151-
dist_fun = Rfast::Dist,
151+
dist_fun = philentropy::distance,
152152
...
153153
) {
154154
silhouette_avg_impl(object, new_data, dists, dist_fun, ...)
@@ -158,7 +158,7 @@ silhouette_avg_impl <- function(
158158
object,
159159
new_data = NULL,
160160
dists = NULL,
161-
dist_fun = Rfast::Dist,
161+
dist_fun = philentropy::distance,
162162
...
163163
) {
164164
mean(silhouette(object, new_data, dists, dist_fun, ...)$sil_width)

R/metric-sse.R

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
#'
2020
#' sse_within(kmeans_fit)
2121
#' @export
22-
sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
22+
sse_within <- function(
23+
object,
24+
new_data = NULL,
25+
dist_fun = function(x, y) {
26+
philentropy::dist_many_many(x, y, method = "euclidean")
27+
}
28+
) {
2329
if (inherits(object, "cluster_spec")) {
2430
rlang::abort(
2531
paste(
@@ -43,7 +49,12 @@ sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
4349
n_members = summ$n_members
4450
)
4551
} else {
46-
dist_to_centroids <- dist_fun(summ$centroids, new_data)
52+
suppressMessages(
53+
dist_to_centroids <- dist_fun(
54+
as.matrix(summ$centroids),
55+
as.matrix(new_data)
56+
)
57+
)
4758

4859
res <- dist_to_centroids %>%
4960
tibble::as_tibble(.name_repair = "minimal") %>%
@@ -121,7 +132,9 @@ sse_within_total.cluster_fit <- function(
121132
...
122133
) {
123134
if (is.null(dist_fun)) {
124-
dist_fun <- Rfast::dista
135+
dist_fun <- function(x, y) {
136+
philentropy::dist_many_many(x, y, method = "euclidean")
137+
}
125138
}
126139

127140
res <- sse_within_total_impl(object, new_data, dist_fun, ...)
@@ -142,7 +155,9 @@ sse_within_total.workflow <- sse_within_total.cluster_fit
142155
sse_within_total_vec <- function(
143156
object,
144157
new_data = NULL,
145-
dist_fun = Rfast::dista,
158+
dist_fun = function(x, y) {
159+
philentropy::dist_many_many(x, y, method = "euclidean")
160+
},
146161
...
147162
) {
148163
sse_within_total_impl(object, new_data, dist_fun, ...)
@@ -151,7 +166,9 @@ sse_within_total_vec <- function(
151166
sse_within_total_impl <- function(
152167
object,
153168
new_data = NULL,
154-
dist_fun = Rfast::dista,
169+
dist_fun = function(x, y) {
170+
philentropy::dist_many_many(x, y, method = "euclidean")
171+
},
155172
...
156173
) {
157174
sum(sse_within(object, new_data, dist_fun, ...)$wss, na.rm = TRUE)
@@ -208,7 +225,9 @@ sse_total.cluster_fit <- function(
208225
...
209226
) {
210227
if (is.null(dist_fun)) {
211-
dist_fun <- Rfast::dista
228+
dist_fun <- function(x, y) {
229+
philentropy::dist_many_many(x, y, method = "euclidean")
230+
}
212231
}
213232

214233
res <- sse_total_impl(object, new_data, dist_fun, ...)
@@ -229,7 +248,9 @@ sse_total.workflow <- sse_total.cluster_fit
229248
sse_total_vec <- function(
230249
object,
231250
new_data = NULL,
232-
dist_fun = Rfast::dista,
251+
dist_fun = function(x, y) {
252+
philentropy::dist_many_many(x, y, method = "euclidean")
253+
},
233254
...
234255
) {
235256
sse_total_impl(object, new_data, dist_fun, ...)
@@ -238,7 +259,9 @@ sse_total_vec <- function(
238259
sse_total_impl <- function(
239260
object,
240261
new_data = NULL,
241-
dist_fun = Rfast::dista,
262+
dist_fun = function(x, y) {
263+
philentropy::dist_many_many(x, y, method = "euclidean")
264+
},
242265
...
243266
) {
244267
# Preprocess data before computing distances if appropriate
@@ -253,7 +276,10 @@ sse_total_impl <- function(
253276
} else {
254277
overall_mean <- colSums(summ$centroids * summ$n_members) /
255278
sum(summ$n_members)
256-
tot <- dist_fun(t(as.matrix(overall_mean)), new_data)^2 %>% sum()
279+
suppressMessages(
280+
tot <- dist_fun(t(as.matrix(overall_mean)), as.matrix(new_data))^2 %>%
281+
sum()
282+
)
257283
}
258284

259285
return(tot)
@@ -310,7 +336,9 @@ sse_ratio.cluster_fit <- function(
310336
...
311337
) {
312338
if (is.null(dist_fun)) {
313-
dist_fun <- Rfast::dista
339+
dist_fun <- function(x, y) {
340+
philentropy::dist_many_many(x, y, method = "euclidean")
341+
}
314342
}
315343
res <- sse_ratio_impl(object, new_data, dist_fun, ...)
316344

@@ -330,7 +358,9 @@ sse_ratio.workflow <- sse_ratio.cluster_fit
330358
sse_ratio_vec <- function(
331359
object,
332360
new_data = NULL,
333-
dist_fun = Rfast::dista,
361+
dist_fun = function(x, y) {
362+
philentropy::dist_many_many(x, y, method = "euclidean")
363+
},
334364
...
335365
) {
336366
sse_ratio_impl(object, new_data, dist_fun, ...)
@@ -339,7 +369,9 @@ sse_ratio_vec <- function(
339369
sse_ratio_impl <- function(
340370
object,
341371
new_data = NULL,
342-
dist_fun = Rfast::dista,
372+
dist_fun = function(x, y) {
373+
philentropy::dist_many_many(x, y, method = "euclidean")
374+
},
343375
...
344376
) {
345377
sse_within_total_vec(object, new_data, dist_fun) /

R/predict_helpers.R

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ make_predictions <- function(x, prefix, n_clusters) {
9696
)
9797

9898
# need this to be obs on rows, dist to new data on cols
99-
dists_new <- Rfast::dista(xnew = training_data, x = new_data, trans = TRUE)
99+
dists_new <- philentropy::dist_many_many(
100+
training_data,
101+
new_data,
102+
method = "euclidean"
103+
)
100104

101105
cluster_dists <- dplyr::bind_cols(data.frame(dists_new), clusters) %>%
102106
dplyr::group_by(.cluster) %>%
@@ -109,7 +113,12 @@ make_predictions <- function(x, prefix, n_clusters) {
109113
## Centroid linkage_method, dist to center
110114

111115
cluster_centers <- extract_centroids(object) %>% dplyr::select(-.cluster)
112-
dists_means <- Rfast::dista(new_data, cluster_centers)
116+
117+
dists_means <- philentropy::dist_many_many(
118+
new_data,
119+
cluster_centers,
120+
method = "euclidean"
121+
)
113122

114123
pred_clusts_num <- apply(dists_means, 1, which.min)
115124
} else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) {

man/dot-hier_clust_fit_stats.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_centroid_dists.Rd

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)