From 4b3b6ef07fb7f7c89e1a6baeede282df5938d96c Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Mon, 28 Mar 2022 14:57:07 +0200 Subject: [PATCH 1/2] support surv.rpart --- NAMESPACE | 2 ++ R/LearnerSurvRpart.R | 23 +++++++++++++++++++++++ man/autoplot.LearnerClassifGlmnet.Rd | 4 ++-- man/autoplot.LearnerClassifRpart.Rd | 8 ++++++-- man/autoplot.LearnerClustHierarchical.Rd | 2 +- tests/testthat/test_LearnerSurvRpart.R | 13 +++++++++++++ 6 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 R/LearnerSurvRpart.R create mode 100644 tests/testthat/test_LearnerSurvRpart.R diff --git a/NAMESPACE b/NAMESPACE index d0299bd4..4d87b050 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,6 +15,7 @@ S3method(autoplot,LearnerClustHierarchical) S3method(autoplot,LearnerRegrCVGlmnet) S3method(autoplot,LearnerRegrGlmnet) S3method(autoplot,LearnerRegrRpart) +S3method(autoplot,LearnerSurvRpart) S3method(autoplot,OptimInstanceSingleCrit) S3method(autoplot,PredictionClassif) S3method(autoplot,PredictionClust) @@ -41,6 +42,7 @@ S3method(plot,LearnerClassifRpart) S3method(plot,LearnerRegrCVGlmnet) S3method(plot,LearnerRegrGlmnet) S3method(plot,LearnerRegrRpart) +S3method(plot,LearnerSurvRpart) S3method(plot,PredictionClassif) S3method(plot,PredictionRegr) S3method(plot,ResampleResult) diff --git a/R/LearnerSurvRpart.R b/R/LearnerSurvRpart.R new file mode 100644 index 00000000..94d0ad88 --- /dev/null +++ b/R/LearnerSurvRpart.R @@ -0,0 +1,23 @@ +#' @export +#' @rdname autoplot.LearnerClassifRpart +autoplot.LearnerSurvRpart = function(object, ...) { # nolint + if (is.null(object$model)) { + stopf("Learner '%s' must be trained first", object$id) + } + require_namespaces(c("partykit", "ggparty")) + + target = all.vars(object$model$terms)[1L] + autoplot(partykit::as.party(object$model), ...) + + ggparty::geom_node_plot(gglist = list( + geom_boxplot(aes_string(target)), + coord_flip(), + theme(axis.ticks.x = element_blank(), axis.text.x = element_blank()) + )) + + ggparty::geom_node_label(aes(label = paste0("n=", .data[["nodesize"]])), + nudge_y = 0.03, ids = "terminal") +} + +#' @export +plot.LearnerSurvRpart = function(x, ...) { + print(autoplot(x, ...)) +} diff --git a/man/autoplot.LearnerClassifGlmnet.Rd b/man/autoplot.LearnerClassifGlmnet.Rd index 7bb78a9c..32d87fc6 100644 --- a/man/autoplot.LearnerClassifGlmnet.Rd +++ b/man/autoplot.LearnerClassifGlmnet.Rd @@ -17,8 +17,8 @@ \method{autoplot}{LearnerRegrGlmnet}(object, ...) } \arguments{ -\item{object}{(\link[mlr3learners:LearnerClassifGlmnet]{mlr3learners::LearnerClassifGlmnet} | \link[mlr3learners:LearnerRegrGlmnet]{mlr3learners::LearnerRegrGlmnet} | -\link[mlr3learners:LearnerRegrCVGlmnet]{mlr3learners::LearnerRegrCVGlmnet} | \link[mlr3learners:LearnerRegrCVGlmnet]{mlr3learners::LearnerRegrCVGlmnet}).} +\item{object}{(\link[mlr3learners:mlr_learners_classif.glmnet]{mlr3learners::LearnerClassifGlmnet} | \link[mlr3learners:mlr_learners_regr.glmnet]{mlr3learners::LearnerRegrGlmnet} | +\link[mlr3learners:mlr_learners_regr.cv_glmnet]{mlr3learners::LearnerRegrCVGlmnet} | \link[mlr3learners:mlr_learners_regr.cv_glmnet]{mlr3learners::LearnerRegrCVGlmnet}).} \item{...}{(\code{any}): Additional arguments, passed down to \code{\link[ggparty:autoplot.party]{ggparty::autoplot.party()}}.} diff --git a/man/autoplot.LearnerClassifRpart.Rd b/man/autoplot.LearnerClassifRpart.Rd index 60d0f34f..255e2c7d 100644 --- a/man/autoplot.LearnerClassifRpart.Rd +++ b/man/autoplot.LearnerClassifRpart.Rd @@ -1,16 +1,20 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/LearnerClassifRpart.R, R/LearnerRegrRpart.R +% Please edit documentation in R/LearnerClassifRpart.R, R/LearnerRegrRpart.R, +% R/LearnerSurvRpart.R \name{autoplot.LearnerClassifRpart} \alias{autoplot.LearnerClassifRpart} \alias{autoplot.LearnerRegrRpart} +\alias{autoplot.LearnerSurvRpart} \title{Plot for LearnerClassifRpart / LearnerRegrRpart} \usage{ \method{autoplot}{LearnerClassifRpart}(object, ...) \method{autoplot}{LearnerRegrRpart}(object, ...) + +\method{autoplot}{LearnerSurvRpart}(object, ...) } \arguments{ -\item{object}{(\link[mlr3:LearnerClassifRpart]{mlr3::LearnerClassifRpart} | \link[mlr3:LearnerRegrRpart]{mlr3::LearnerRegrRpart}).} +\item{object}{(\link[mlr3:mlr_learners_classif.rpart]{mlr3::LearnerClassifRpart} | \link[mlr3:mlr_learners_regr.rpart]{mlr3::LearnerRegrRpart}).} \item{...}{(\code{any}): Additional arguments, passed down to \code{\link[ggparty:autoplot.party]{ggparty::autoplot.party()}}.} diff --git a/man/autoplot.LearnerClustHierarchical.Rd b/man/autoplot.LearnerClustHierarchical.Rd index 3ef98181..9a11070d 100644 --- a/man/autoplot.LearnerClustHierarchical.Rd +++ b/man/autoplot.LearnerClustHierarchical.Rd @@ -7,7 +7,7 @@ \method{autoplot}{LearnerClustHierarchical}(object, type = "dend", ...) } \arguments{ -\item{object}{(\link[mlr3cluster:LearnerClustAgnes]{mlr3cluster::LearnerClustAgnes} | \link[mlr3cluster:LearnerClustDiana]{mlr3cluster::LearnerClustDiana} | \link[mlr3cluster:LearnerClustHclust]{mlr3cluster::LearnerClustHclust}).} +\item{object}{(\link[mlr3cluster:mlr_learners_clust.agnes]{mlr3cluster::LearnerClustAgnes} | \link[mlr3cluster:mlr_learners_clust.diana]{mlr3cluster::LearnerClustDiana} | \link[mlr3cluster:mlr_learners_clust.hclust]{mlr3cluster::LearnerClustHclust}).} \item{type}{(character(1)):\cr Type of the plot. See description.} diff --git a/tests/testthat/test_LearnerSurvRpart.R b/tests/testthat/test_LearnerSurvRpart.R new file mode 100644 index 00000000..07de5511 --- /dev/null +++ b/tests/testthat/test_LearnerSurvRpart.R @@ -0,0 +1,13 @@ +skip_if_not_installed("survival") +skip_if_not_installed("mlr3proba") +skip_if_not_installed("rpart") +skip_if_not_installed("partykit") +skip_if_not_installed("ggparty") + +test_that("autoplot.LearnerSurvRpart", { + learner = mlr3::lrn("surv.rpart")$train(mlr3::tsk("rats")) + p = autoplot(learner) + expect_true(is.ggplot(p)) + vdiffr::expect_doppelganger("learner_regr.rpart", p) +}) + From 093b6f097c5732393df3d59c9630d3c4caab1ed2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Mar 2022 12:59:03 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- man/autoplot.LearnerClassifGlmnet.Rd | 4 ++-- man/autoplot.LearnerClassifRpart.Rd | 2 +- man/autoplot.LearnerClustHierarchical.Rd | 2 +- tests/testthat/test_LearnerSurvRpart.R | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/man/autoplot.LearnerClassifGlmnet.Rd b/man/autoplot.LearnerClassifGlmnet.Rd index 32d87fc6..7bb78a9c 100644 --- a/man/autoplot.LearnerClassifGlmnet.Rd +++ b/man/autoplot.LearnerClassifGlmnet.Rd @@ -17,8 +17,8 @@ \method{autoplot}{LearnerRegrGlmnet}(object, ...) } \arguments{ -\item{object}{(\link[mlr3learners:mlr_learners_classif.glmnet]{mlr3learners::LearnerClassifGlmnet} | \link[mlr3learners:mlr_learners_regr.glmnet]{mlr3learners::LearnerRegrGlmnet} | -\link[mlr3learners:mlr_learners_regr.cv_glmnet]{mlr3learners::LearnerRegrCVGlmnet} | \link[mlr3learners:mlr_learners_regr.cv_glmnet]{mlr3learners::LearnerRegrCVGlmnet}).} +\item{object}{(\link[mlr3learners:LearnerClassifGlmnet]{mlr3learners::LearnerClassifGlmnet} | \link[mlr3learners:LearnerRegrGlmnet]{mlr3learners::LearnerRegrGlmnet} | +\link[mlr3learners:LearnerRegrCVGlmnet]{mlr3learners::LearnerRegrCVGlmnet} | \link[mlr3learners:LearnerRegrCVGlmnet]{mlr3learners::LearnerRegrCVGlmnet}).} \item{...}{(\code{any}): Additional arguments, passed down to \code{\link[ggparty:autoplot.party]{ggparty::autoplot.party()}}.} diff --git a/man/autoplot.LearnerClassifRpart.Rd b/man/autoplot.LearnerClassifRpart.Rd index 255e2c7d..221faf9f 100644 --- a/man/autoplot.LearnerClassifRpart.Rd +++ b/man/autoplot.LearnerClassifRpart.Rd @@ -14,7 +14,7 @@ \method{autoplot}{LearnerSurvRpart}(object, ...) } \arguments{ -\item{object}{(\link[mlr3:mlr_learners_classif.rpart]{mlr3::LearnerClassifRpart} | \link[mlr3:mlr_learners_regr.rpart]{mlr3::LearnerRegrRpart}).} +\item{object}{(\link[mlr3:LearnerClassifRpart]{mlr3::LearnerClassifRpart} | \link[mlr3:LearnerRegrRpart]{mlr3::LearnerRegrRpart}).} \item{...}{(\code{any}): Additional arguments, passed down to \code{\link[ggparty:autoplot.party]{ggparty::autoplot.party()}}.} diff --git a/man/autoplot.LearnerClustHierarchical.Rd b/man/autoplot.LearnerClustHierarchical.Rd index 9a11070d..3ef98181 100644 --- a/man/autoplot.LearnerClustHierarchical.Rd +++ b/man/autoplot.LearnerClustHierarchical.Rd @@ -7,7 +7,7 @@ \method{autoplot}{LearnerClustHierarchical}(object, type = "dend", ...) } \arguments{ -\item{object}{(\link[mlr3cluster:mlr_learners_clust.agnes]{mlr3cluster::LearnerClustAgnes} | \link[mlr3cluster:mlr_learners_clust.diana]{mlr3cluster::LearnerClustDiana} | \link[mlr3cluster:mlr_learners_clust.hclust]{mlr3cluster::LearnerClustHclust}).} +\item{object}{(\link[mlr3cluster:LearnerClustAgnes]{mlr3cluster::LearnerClustAgnes} | \link[mlr3cluster:LearnerClustDiana]{mlr3cluster::LearnerClustDiana} | \link[mlr3cluster:LearnerClustHclust]{mlr3cluster::LearnerClustHclust}).} \item{type}{(character(1)):\cr Type of the plot. See description.} diff --git a/tests/testthat/test_LearnerSurvRpart.R b/tests/testthat/test_LearnerSurvRpart.R index 07de5511..64ea64aa 100644 --- a/tests/testthat/test_LearnerSurvRpart.R +++ b/tests/testthat/test_LearnerSurvRpart.R @@ -10,4 +10,3 @@ test_that("autoplot.LearnerSurvRpart", { expect_true(is.ggplot(p)) vdiffr::expect_doppelganger("learner_regr.rpart", p) }) -