diff --git a/DESCRIPTION b/DESCRIPTION index 4f12e0e9..8fc0d3fc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -62,6 +62,7 @@ Suggests: rsample, RSpectra, survival (>= 3.2-10), + tabnet, testthat (>= 3.0.0), TH.data, usethis (>= 1.5.0), diff --git a/NAMESPACE b/NAMESPACE index 993a1abc..aa746464 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -86,6 +86,7 @@ S3method(axe_env,terms) S3method(axe_env,train) S3method(axe_env,train.recipe) S3method(axe_env,xgb.Booster) +S3method(axe_fitted,"_tabnet_fit") S3method(axe_fitted,C5.0) S3method(axe_fitted,default) S3method(axe_fitted,earth) diff --git a/NEWS.md b/NEWS.md index d15bfb6b..1bc61eff 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # butcher (development version) +* Added butcher methods for `tabnet()` (@cregouby #226). + # butcher 0.3.0 * Julia Silge is now the maintainer (#230). diff --git a/R/tabnet_fit.R b/R/tabnet_fit.R new file mode 100644 index 00000000..f5f06c47 --- /dev/null +++ b/R/tabnet_fit.R @@ -0,0 +1,45 @@ +#' Axing a tabnet_fit. +#' +#' @inheritParams butcher +#' +#' @return Axed tabnet_fit object. +#' +#' @examplesIf rlang::is_installed("tabnet") +#' +#' # Load libraries +#' suppressWarnings(suppressMessages(library(parsnip))) +#' suppressWarnings(suppressMessages(library(rsample))) +#' +#' # Load data +#' split <- initial_split(mtcars, props = 9/10) +#' car_train <- training(split) +#' +#' # Create model and fit +#' mtcar_fit <- tabnet() %>% +#' set_mode("regression") %>% +#' set_engine("torch") +#' fit(mpg ~ ., data = car_train) +#' +#' out <- butcher(mtcar_fit, verbose = TRUE) +#' +#' @name axe-tabnet_fit +NULL + +#' Remove fitted values. +#' +#' @rdname axe-tabnet_fit +#' @export +axe_fitted._tabnet_fit <- function(x, verbose = FALSE, ...) { + old <- x + x$fit$fit <- exchange(x$fit$fit, "checkpoints", list(NULL)) + x$fit$fit$importances <- exchange(x$fit$fit$importances, "variables", list(NULL)) + x$fit$fit$importances <- exchange(x$fit$fit$importances, "importance", list(NULL)) + + add_butcher_attributes( + x, + old, + disabled = NULL, + add_class = FALSE, + verbose = verbose + ) +} diff --git a/man/axe-tabnet_fit.Rd b/man/axe-tabnet_fit.Rd new file mode 100644 index 00000000..d2a8b261 --- /dev/null +++ b/man/axe-tabnet_fit.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/tabnet_fit.R +\name{axe-tabnet_fit} +\alias{axe-tabnet_fit} +\alias{axe_fitted._tabnet_fit} +\title{Axing a tabnet_fit.} +\usage{ +\method{axe_fitted}{`_tabnet_fit`}(x, verbose = FALSE, ...) +} +\arguments{ +\item{x}{A model object.} + +\item{verbose}{Print information each time an axe method is executed. +Notes how much memory is released and what functions are +disabled. Default is \code{FALSE}.} + +\item{...}{Any additional arguments related to axing.} +} +\value{ +Axed tabnet_fit object. +} +\description{ +Axing a tabnet_fit. + +Remove fitted values. +} +\examples{ +\dontshow{if (rlang::is_installed("tabnet")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} + +# Load libraries +suppressWarnings(suppressMessages(library(parsnip))) +suppressWarnings(suppressMessages(library(rsample))) + +# Load data +split <- initial_split(mtcars, props = 9/10) +car_train <- training(split) + +# Create model and fit +mtcar_fit <- tabnet() \%>\% + set_mode("regression") \%>\% + set_engine("torch") + fit(mpg ~ ., data = car_train) + +out <- butcher(mtcar_fit, verbose = TRUE) +\dontshow{\}) # examplesIf} +} diff --git a/tests/testthat/test-tabnet_fit.R b/tests/testthat/test-tabnet_fit.R new file mode 100644 index 00000000..c4de0a68 --- /dev/null +++ b/tests/testthat/test-tabnet_fit.R @@ -0,0 +1,41 @@ +test_that("tabnet_fit + axe_fitted() works", { + skip_on_cran() + skip_if_not_installed("tabnet") + suppressPackageStartupMessages(library(parsnip)) + # Create model and fit + tabnet_fit <- tabnet::tabnet(epochs = 10) %>% + set_mode("regression") %>% + set_engine("torch") %>% + fit(mpg ~ ., data = mtcars) + + expect_error(axed_out <- axe_fitted(tabnet_fit, verbose = TRUE), NA) + expect_lt(lobstr::obj_size(axed_out),lobstr::obj_size(tabnet_fit)) +}) + +test_that("tabnet_fit + butcher() works", { + skip_on_cran() + skip_if_not_installed("tabnet") + suppressPackageStartupMessages(library(parsnip)) + # Create model and fit + tabnet_fit <- tabnet::tabnet(epochs = 10) %>% + set_mode("regression") %>% + set_engine("torch") %>% + fit(mpg ~ ., data = mtcars) + + expect_error(tabnet_out <- butcher(tabnet_fit, verbose = TRUE), NA) +}) + +test_that("tabnet_fit + predict() works", { + skip_on_cran() + skip_if_not_installed("tabnet") + suppressPackageStartupMessages(library(parsnip)) + # Create model and fit + tabnet_fit <- tabnet::tabnet(epochs = 10) %>% + set_mode("regression") %>% + set_engine("torch") %>% + fit(mpg ~ ., data = mtcars) + + tabnet_out <- butcher(tabnet_fit, verbose = TRUE) + new_data <- as.matrix(mtcars[1:3, 2:11]) + expect_equal(predict(tabnet_out,new_data), predict(tabnet_fit, new_data)) +})