diff --git a/R/evaluation.R b/R/evaluation.R index 709ecd0..fc1b589 100644 --- a/R/evaluation.R +++ b/R/evaluation.R @@ -34,7 +34,7 @@ evaluate_survdnn <- function(model, data <- if (is.null(newdata)) model$data else newdata n_before <- nrow(data) - # Build model frame first with explicit NA policy (so y aligns with predictions) + # build model frame first with explicit NA policy mf <- model.frame( model$formula, data = data, @@ -50,10 +50,10 @@ evaluate_survdnn <- function(model, y <- model.response(mf) if (!inherits(y, "Surv")) stop("The response must be a 'Surv' object.", call. = FALSE) - # Predict on the filtered mf to keep row alignment + # predict on the filtered mf to keep row alignment sp_matrix <- predict(model, newdata = mf, times = times, type = "survival") - purrr::map_dfr(metrics, function(metric) { + purrr::map_dfr(metrics, function(metric) { ## to replace with fmap from functionals package if (metric == "brier" && length(times) > 1) { tibble::tibble( metric = "brier", diff --git a/R/tune_survdnn.R b/R/tune_survdnn.R index 46cbbaa..d0ed91d 100644 --- a/R/tune_survdnn.R +++ b/R/tune_survdnn.R @@ -83,7 +83,7 @@ tune_survdnn <- function(formula, summary_tbl <- summarize_tune_survdnn(all_results, by_time = FALSE) - ## Select best hyperparameters + ## select best hyperparameters primary_metric <- metrics[1] best_row_all <- all_results |> @@ -99,7 +99,7 @@ tune_survdnn <- function(formula, stop("No valid configuration found for primary metric: ", primary_metric, call. = FALSE) } - ## Refitting the best model + ## refitting the best model if (refit) { message("Refitting best model on full data...") best_model <- survdnn( diff --git a/R/zzz.R b/R/zzz.R index 741f09c..460d1fc 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -9,14 +9,11 @@ "fold", "metric", "value", "id", "time", "surv", "group", "mean_surv", "n", "se", "hidden", "lr", "activation", "epochs", "loss_name", ".loss_fn" )) - - # IMPORTANT: never load or probe torch here (because CRAN/Windows may segfault). - # No torch checks, no tensor creation on load. + # never load or probe torch here (because CRAN/Windows may segfault) } ## handles user-facing messaging .onAttach <- function(libname, pkgname) { - # Do NOT load torch or call torch::torch_is_installed() here. # friendly hint that doesn't load the namespace: torch_pkg_present <- nzchar(system.file(package = "torch")) @@ -51,7 +48,7 @@ survdnn_set_seed <- function(.seed = NULL) { -## Internal utility to choose a torch device for survdnn +## internal utility to choose a torch device for survdnn survdnn_get_device <- function(.device = c("auto", "cpu", "cuda")) { .device <- match.arg(.device) diff --git a/README.Rmd b/README.Rmd index 14faaf1..16bdd2a 100644 --- a/README.Rmd +++ b/README.Rmd @@ -4,7 +4,7 @@ output: github_document # survdnn -> Deep Neural Networks for Survival Analysis using [torch](https://torch.mlverse.org/) +> Deep Neural Networks for Survival Analysis using [R torch](https://torch.mlverse.org/) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) [![R-CMD-check](https://github.com/ielbadisy/survdnn/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/ielbadisy/survdnn/actions/workflows/R-CMD-check.yaml) @@ -43,7 +43,7 @@ A methodological paper describing the design, implementation, and evaluation of ```{r, eval = FALSE} # Install from CRAN -install.packages("surdnn") +install.packages("survdnn") # Install from GitHub diff --git a/README.html b/README.html deleted file mode 100644 index 346a150..0000000 --- a/README.html +++ /dev/null @@ -1,906 +0,0 @@ - - - - - - - - - - - - - - - - - - - -

survdnn -

-
-

Deep Neural Networks for Survival Analysis using torch

-
-

License: MIT
-R-CMD-check

-

About

-

survdnn implements neural network-based models for -right-censored survival analysis using the native torch -backend in R. It supports multiple loss functions including Cox partial -likelihood, L2-penalized Cox, Accelerated Failure Time (AFT) objectives, -as well as time-dependent extension such as Cox-Time. The package -provides a formula interface, supports model evaluation using -time-dependent metrics (C-index, Brier score, IBS), cross-validation, -and hyperparameter tuning.

-

Review status

-

A methodological paper describing the design, implementation, and -evaluation of survdnn is currently under review at The -R Journal.

-

Main features

- -

Installation

-
# Install from CRAN
-install.packages("surdnn")
-
-
-# Install from GitHub
-install.packages("remotes")
-remotes::install_github("ielbadisy/survdnn")
-
-# Or clone and install locally
-git clone https://github.com/ielbadisy/survdnn.git
-setwd("survdnn")
-devtools::install()
-

Quick example

-
library(survdnn)
-library(survival, quietly = TRUE)
-library(ggplot2)
-
-veteran <- survival::veteran
-
-mod <- survdnn(
-  Surv(time, status) ~ age + karno + celltype,
-  data = veteran,
-  hidden = c(32, 16),
-  epochs = 300,
-  loss = "cox",
-  verbose = TRUE
-  )
-
## Epoch 50 - Loss: 3.983201
-## 
-## Epoch 100 - Loss: 3.947356
-## 
-## Epoch 150 - Loss: 3.934828
-## 
-## Epoch 200 - Loss: 3.876191
-## 
-## Epoch 250 - Loss: 3.813223
-## 
-## Epoch 300 - Loss: 3.868888
-
summary(mod)
-
## 
-## Formula:
-##   Surv(time, status) ~ age + karno + celltype
-## <environment: 0x611459d0ec80>
-## 
-## Model architecture:
-##   Hidden layers:  32 : 16 
-##   Activation:  relu 
-##   Dropout:  0.3 
-##   Final loss:  3.868888 
-## 
-## Training summary:
-##   Epochs:  300 
-##   Learning rate:  1e-04 
-##   Loss function:  cox 
-##   Optimizer:  adam 
-## 
-## Data summary:
-##   Observations:  137 
-##   Predictors:  age, karno, celltypesmallcell, celltypeadeno, celltypelarge 
-##   Time range: [ 1, 999 ]
-##   Event rate:  93.4%
-
plot(mod, group_by = "celltype", times = 1:300)
-

Loss Functions

- -
mod1 <- survdnn(
-  Surv(time, status) ~ age + karno,
-  data = veteran,
-  loss = "cox",
-  epochs = 300
-  )
-
## Epoch 50 - Loss: 3.986035
-## 
-## Epoch 100 - Loss: 3.973183
-## 
-## Epoch 150 - Loss: 3.944867
-## 
-## Epoch 200 - Loss: 3.901533
-## 
-## Epoch 250 - Loss: 3.849433
-## 
-## Epoch 300 - Loss: 3.899746
- -
mod2 <- survdnn(
-  Surv(time, status) ~ age + karno,
-  data = veteran,
-  loss = "aft",
-  epochs = 300
-  )
-
## Epoch 50 - Loss: 18.154217
-## 
-## Epoch 100 - Loss: 17.844833
-## 
-## Epoch 150 - Loss: 17.560537
-## 
-## Epoch 200 - Loss: 17.134348
-## 
-## Epoch 250 - Loss: 16.840366
-## 
-## Epoch 300 - Loss: 16.344124
- -
mod3 <- survdnn(
-  Surv(time, status) ~ age + karno,
-  data = veteran,
-  loss = "coxtime",
-  epochs = 300
-  )
-
## Epoch 50 - Loss: 4.932558
-## 
-## Epoch 100 - Loss: 4.864682
-## 
-## Epoch 150 - Loss: 4.830169
-## 
-## Epoch 200 - Loss: 4.784954
-## 
-## Epoch 250 - Loss: 4.764827
-## 
-## Epoch 300 - Loss: 4.731824
-

Cross-validation

-
cv_results <- cv_survdnn(
-  Surv(time, status) ~ age + karno + celltype,
-  data = veteran,
-  times = c(600),
-  metrics = c("cindex", "ibs"),
-  folds = 3,
-  hidden = c(16, 8),
-  loss = "cox",
-  epochs = 300
-  )
-
-print(cv_results)
-

Hyperparameter tuning

-
grid <- list(
-  hidden     = list(c(16), c(32, 16)),
-  lr         = c(1e-3),
-  activation = c("relu"),
-  epochs     = c(100, 300),
-  loss       = c("cox", "aft", "coxtime")
-  )
-
-tune_res <- tune_survdnn(
-  formula = Surv(time, status) ~ age + karno + celltype,
-  data = veteran,
-  times = c(90, 300),
-  metrics = "cindex",
-  param_grid = grid,
-  folds = 3,
-  refit = FALSE,
-  return = "summary"
-  )
-
-print(tune_res)
-

Tuning and refitting the -best Model

-

tune_survdnn() can be used also to automatically refit -the best-performing model on the full dataset. This behavior is -controlled by the refit and return arguments. -For example:

-
best_model <- tune_survdnn(
-  formula = Surv(time, status) ~ age + karno + celltype,
-  data = veteran,
-  times = c(90, 300),
-  metrics = "cindex",
-  param_grid = grid,
-  folds = 3,
-  refit = TRUE,
-  return = "best_model"
-  )
-

In this mode, cross-validation is used to select the optimal -hyperparameter configuration, after which the selected model is refitted -on the full dataset. The function then returns a fitted object of class -"survdnn".

-

The resulting model can be used directly for prediction -visualization, and evaluation:

-
summary(best_model)
-
-plot(best_model, times = 1:300)
-
-predict(best_model, veteran, type = "risk", times = 180)
-

This makes tune_survdnn() suitable for end-to-end -workflows, combining model selection and final model fitting.

-

Plot survival curves

-
plot(mod1, group_by = "celltype", times = 1:300)
-

-
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
-

-

Documentation

-
help(package = "survdnn")
-?survdnn
-?tune_survdnn
-?cv_survdnn
-?plot.survdnn
-

Testing

-
# run all tests
-devtools::test()
-

Note on reproducibility

-

By default, {torch} initializes model weights and -shuffles minibatches using random draws, so results may differ across -runs. Unlike set.seed(), which only controls R’s random -number generator, {torch} relies on its own RNG implemented -in C++ (and CUDA when using GPUs).

-

To ensure reproducibility, random seeds must therefore be set at the -Torch level as well.

-

survdnn provides built-in control of randomness to -guarantee reproducible results across runs. The main fitting function, -survdnn(), exposes a dedicated .seed -argument:

-
mod <- survdnn(
-  Surv(time, status) ~ age + karno + celltype,
-  data   = veteran,
-  epochs = 300,
-  .seed  = 123
-)
-

When .seed is provided, survdnn() -internally synchronizes both R and Torch random number generators via -survdnn_set_seed(), ensuring reproducible:

- -

If .seed = NULL (the default), randomness is left -uncontrolled and results may vary between runs.

-

For full reproducibility in cross-validation or hyperparameter -tuning, the same .seed mechanism is propagated internally -by cv_survdnn() and tune_survdnn(), ensuring -consistent data splits, model initialization, and optimization paths -across repetitions.

-

CPU and core usage

-

survdnn relies on the {torch} backend for -numerical computation. The number of CPU cores (threads) used during -training, prediction, and evaluation is controlled globally by -Torch.

-

By default, Torch automatically configures its CPU thread pools based -on the available system resources, unless explicitly overridden by the -user using:

-
torch::torch_set_num_threads(4)
-

This setting affects:

- -

GPU acceleration can be enabled by setting -.device = "cuda" when calling survdnn() -(cv_survdnn() and tune_survdnn() too).

-

Availability

-

The survdnn R package is available on CRAN or github

-

Contributions

-

Contributions, issues, and feature requests are welcome!

-

Open an issue or submit a -pull request.

-

License

-

MIT License © 2025 Imad EL BADISY

- - - diff --git a/README.md b/README.md index 53c668e..1d7dba4 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # survdnn -> Deep Neural Networks for Survival Analysis using -> [torch](https://torch.mlverse.org/) +> Deep Neural Networks for Survival Analysis using [R +> torch](https://torch.mlverse.org/) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) @@ -48,7 +48,7 @@ evaluation of `survdnn` is currently under review at *The R Journal*. ``` r # Install from CRAN -install.packages("surdnn") +install.packages("survdnn") # Install from GitHub @@ -80,17 +80,17 @@ mod <- survdnn( ) ``` - ## Epoch 50 - Loss: 3.983201 + ## Epoch 50 - Loss: 3.967377 ## - ## Epoch 100 - Loss: 3.947356 + ## Epoch 100 - Loss: 3.863189 ## - ## Epoch 150 - Loss: 3.934828 + ## Epoch 150 - Loss: 3.879065 ## - ## Epoch 200 - Loss: 3.876191 + ## Epoch 200 - Loss: 3.814478 ## - ## Epoch 250 - Loss: 3.813223 + ## Epoch 250 - Loss: 3.756944 ## - ## Epoch 300 - Loss: 3.868888 + ## Epoch 300 - Loss: 3.823366 ``` r summary(mod) @@ -99,13 +99,13 @@ summary(mod) ## ## Formula: ## Surv(time, status) ~ age + karno + celltype - ## + ## ## ## Model architecture: ## Hidden layers: 32 : 16 ## Activation: relu ## Dropout: 0.3 - ## Final loss: 3.868888 + ## Final loss: 3.823366 ## ## Training summary: ## Epochs: 300 @@ -136,17 +136,17 @@ mod1 <- survdnn( ) ``` - ## Epoch 50 - Loss: 3.986035 + ## Epoch 50 - Loss: 3.988259 ## - ## Epoch 100 - Loss: 3.973183 + ## Epoch 100 - Loss: 3.930287 ## - ## Epoch 150 - Loss: 3.944867 + ## Epoch 150 - Loss: 3.913787 ## - ## Epoch 200 - Loss: 3.901533 + ## Epoch 200 - Loss: 3.896528 ## - ## Epoch 250 - Loss: 3.849433 + ## Epoch 250 - Loss: 3.819792 ## - ## Epoch 300 - Loss: 3.899746 + ## Epoch 300 - Loss: 3.893889 - Accelerated Failure Time @@ -159,17 +159,17 @@ mod2 <- survdnn( ) ``` - ## Epoch 50 - Loss: 18.154217 + ## Epoch 50 - Loss: 16.911470 ## - ## Epoch 100 - Loss: 17.844833 + ## Epoch 100 - Loss: 16.589067 ## - ## Epoch 150 - Loss: 17.560537 + ## Epoch 150 - Loss: 16.226612 ## - ## Epoch 200 - Loss: 17.134348 + ## Epoch 200 - Loss: 15.959708 ## - ## Epoch 250 - Loss: 16.840366 + ## Epoch 250 - Loss: 15.182121 ## - ## Epoch 300 - Loss: 16.344124 + ## Epoch 300 - Loss: 15.049762 - Coxtime @@ -182,17 +182,17 @@ mod3 <- survdnn( ) ``` - ## Epoch 50 - Loss: 4.932558 + ## Epoch 50 - Loss: 4.888907 ## - ## Epoch 100 - Loss: 4.864682 + ## Epoch 100 - Loss: 4.846722 ## - ## Epoch 150 - Loss: 4.830169 + ## Epoch 150 - Loss: 4.838490 ## - ## Epoch 200 - Loss: 4.784954 + ## Epoch 200 - Loss: 4.816662 ## - ## Epoch 250 - Loss: 4.764827 + ## Epoch 250 - Loss: 4.780379 ## - ## Epoch 300 - Loss: 4.731824 + ## Epoch 300 - Loss: 4.756117 ## Cross-validation diff --git a/README_files/figure-gfm/unnamed-chunk-11-1.png b/README_files/figure-gfm/unnamed-chunk-11-1.png index b79fc28..1ae8cf2 100644 Binary files a/README_files/figure-gfm/unnamed-chunk-11-1.png and b/README_files/figure-gfm/unnamed-chunk-11-1.png differ diff --git a/README_files/figure-gfm/unnamed-chunk-12-1.png b/README_files/figure-gfm/unnamed-chunk-12-1.png index 744d754..0954982 100644 Binary files a/README_files/figure-gfm/unnamed-chunk-12-1.png and b/README_files/figure-gfm/unnamed-chunk-12-1.png differ diff --git a/man/build_dnn.Rd b/man/build_dnn.Rd index e6af471..b6fc1a6 100644 --- a/man/build_dnn.Rd +++ b/man/build_dnn.Rd @@ -19,7 +19,8 @@ build_dnn( \item{hidden}{Integer vector. Sizes of the hidden layers (e.g., c(32, 16)).} \item{activation}{Character. Name of the activation function to use in each layer. -Supported options: `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`, `"elu"`, `"softplus"`.} +Supported options: `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`, +`"elu"`, `"softplus"`.} \item{output_dim}{Integer. Output layer dimension (default = 1).} diff --git a/man/survdnn.Rd b/man/survdnn.Rd index 0e3385f..f62f471 100644 --- a/man/survdnn.Rd +++ b/man/survdnn.Rd @@ -24,76 +24,147 @@ survdnn( ) } \arguments{ -\item{formula}{A survival formula of the form `Surv(time, status) ~ predictors`.} +\item{formula}{A survival formula of the form \code{Surv(time, status) ~ predictors}.} \item{data}{A data frame containing the variables in the model.} -\item{hidden}{Integer vector. Sizes of the hidden layers (default: c(32, 16)).} +\item{hidden}{Integer vector giving hidden layer widths (e.g., \code{c(32L, 16L)}).} -\item{activation}{Character string specifying the activation function to use in each layer. -Supported options: `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`, `"elu"`, `"softplus"`.} +\item{activation}{Activation function used in each hidden layer. One of +\code{"relu"}, \code{"leaky_relu"}, \code{"tanh"}, \code{"sigmoid"}, +\code{"gelu"}, \code{"elu"}, \code{"softplus"}.} -\item{lr}{Learning rate for the optimizer (default: `1e-4`).} +\item{lr}{Learning rate passed to the optimizer (default \code{1e-4}).} -\item{epochs}{Number of training epochs (default: 300).} +\item{epochs}{Number of training epochs (default \code{300L}).} -\item{loss}{Character name of the loss function to use. One of `"cox"`, `"cox_l2"`, `"aft"`, or `"coxtime"`.} +\item{loss}{Loss function to optimize. One of \code{"cox"}, \code{"cox_l2"}, +\code{"aft"}, \code{"coxtime"}.} -\item{optimizer}{Character string specifying the optimizer to use. One of -`"adam"`, `"adamw"`, `"sgd"`, `"rmsprop"`, or `"adagrad"`. Defaults to `"adam"`.} +\item{optimizer}{Optimizer name. One of \code{"adam"}, \code{"adamw"}, +\code{"sgd"}, \code{"rmsprop"}, \code{"adagrad"}.} -\item{optim_args}{Optional named list of additional arguments passed to the -underlying torch optimizer (e.g., `list(weight_decay = 1e-4, momentum = 0.9)`).} +\item{optim_args}{Optional named list of extra arguments passed to the chosen +torch optimizer (e.g., \code{list(weight_decay = 1e-4, momentum = 0.9)}).} -\item{verbose}{Logical; whether to print loss progress every 50 epochs (default: TRUE).} +\item{verbose}{Logical; whether to print training progress every 50 epochs.} -\item{dropout}{Numeric between 0 and 1. Dropout rate applied after each -hidden layer (default = 0.3). Set to 0 to disable dropout.} +\item{dropout}{Dropout rate applied after each hidden layer (set \code{0} to disable).} -\item{batch_norm}{Logical; whether to add batch normalization after each -hidden linear layer (default = TRUE).} +\item{batch_norm}{Logical; whether to add batch normalization after each hidden linear layer.} -\item{callbacks}{Optional list of callback functions. Each callback should have -signature `function(epoch, current)` and return TRUE if training should stop, -FALSE otherwise. Used, for example, with [callback_early_stopping()].} +\item{callbacks}{Optional callback(s) for early stopping or monitoring. +May be \code{NULL}, a single function, or a list of functions. Each callback must have +signature \code{function(epoch, current_loss)} and return \code{TRUE} to stop training, +\code{FALSE} otherwise.} -\item{.seed}{Optional integer. If provided, sets both R and torch random seeds for reproducible -weight initialization, shuffling, and dropout.} +\item{.seed}{Optional integer seed controlling both R and torch RNGs (weight init, +shuffling, dropout) for reproducibility.} -\item{.device}{Character string indicating the computation device. -One of `"auto"`, `"cpu"`, or `"cuda"`. `"auto"` uses CUDA if available, -otherwise falls back to CPU.} +\item{.device}{Computation device. One of \code{"auto"}, \code{"cpu"}, \code{"cuda"}. +\code{"auto"} selects CUDA when available.} -\item{na_action}{Character. How to handle missing values in the model variables: -`"omit"` drops incomplete rows (and reports how many were removed when `verbose=TRUE`); -`"fail"` stops with an error if any missing values are present.} +\item{na_action}{Missing-data handling. \code{"omit"} drops incomplete rows (and reports +how many were removed when \code{verbose=TRUE}); \code{"fail"} errors if any missing +values are present in model variables.} } \value{ -An object of class `"survdnn"` containing: +An object of class \code{"survdnn"} with components: \describe{ - \item{model}{Trained `nn_module` object.} - \item{formula}{Original survival formula.} - \item{data}{Training data used for fitting.} - \item{xnames}{Predictor variable names.} - \item{x_center}{Column means of predictors.} - \item{x_scale}{Column standard deviations of predictors.} - \item{loss_history}{Vector of loss values per epoch.} - \item{final_loss}{Final training loss.} - \item{loss}{Loss function name used ("cox", "aft", etc.).} - \item{activation}{Activation function used.} + \item{model}{Trained torch \code{nn_module} (MLP).} + \item{formula}{Model formula used for fitting.} + \item{data}{Training data used for fitting (original \code{data} argument).} + \item{xnames}{Predictor column names used by the model matrix.} + \item{x_center}{Numeric vector of predictor means used for scaling.} + \item{x_scale}{Numeric vector of predictor standard deviations used for scaling.} + \item{loss_history}{Numeric vector of loss values per epoch (possibly truncated by early stopping).} + \item{final_loss}{Final loss value (last element of \code{loss_history}).} + \item{loss}{Loss name used for training.} + \item{activation}{Activation function name.} \item{hidden}{Hidden layer sizes.} \item{lr}{Learning rate.} - \item{epochs}{Number of training epochs.} - \item{optimizer}{Optimizer name used.} + \item{epochs}{Number of requested epochs.} + \item{optimizer}{Optimizer name.} \item{optim_args}{List of optimizer arguments used.} - \item{device}{Torch device used for training (`torch_device`).} - \item{aft_log_sigma}{Learned global log(sigma) for `loss="aft"`; `NA_real_` otherwise.} - \item{aft_loc}{AFT log-time location offset used for centering when `loss="aft"`; `NA_real_` otherwise.} - \item{coxtime_time_center}{Mean used to scale time for CoxTime; `NA_real_` otherwise.} - \item{coxtime_time_scale}{SD used to scale time for CoxTime; `NA_real_` otherwise.} + \item{device}{Torch device used for fitting.} + \item{dropout}{Dropout rate used.} + \item{batch_norm}{Whether batch normalization was used.} + \item{na_action}{Missing-data strategy used.} + \item{aft_log_sigma}{Learned global \code{log(sigma)} for AFT; \code{NA_real_} otherwise.} + \item{aft_loc}{Log-time centering offset used for AFT; \code{NA_real_} otherwise.} + \item{coxtime_time_center}{Time centering used for CoxTime; \code{NA_real_} otherwise.} + \item{coxtime_time_scale}{Time scaling used for CoxTime; \code{NA_real_} otherwise.} } } \description{ -Trains a deep neural network (DNN) to model right-censored survival data -using one of the predefined loss functions: Cox, AFT, or Coxtime. +Fits a deep neural network (MLP) for right-censored time-to-event data using +one of the supported losses: Cox partial likelihood, L2-penalized Cox, +log-normal AFT (censored negative log-likelihood), or CoxTime (time-dependent +relative risk model). +} +\details{ +The function: +\itemize{ + \item builds an MLP via [build_dnn()], + \item preprocesses predictors using centering/scaling (stored in the model), + \item optionally applies log-time centering for AFT (stored as \code{aft_loc}), + \item scales time for CoxTime to stabilize optimization (stored as \code{coxtime_time_center}/\code{coxtime_time_scale}), + \item trains the network with a torch optimizer and optional callbacks. +} + + +\strong{AFT model.} With \code{loss="aft"}, the model is a log-normal AFT model: +\deqn{\log(T) = \text{aft\_loc} + \mu_{\text{resid}}(x) + \sigma \varepsilon, \quad \varepsilon \sim \mathcal{N}(0,1).} +For numerical stability, training uses centered log-times +\code{log(time) - aft_loc}. The learned network output corresponds to +\code{mu_resid(x)}. The fitted object stores \code{aft_loc} and the learned global +\code{aft_log_sigma}. + +\strong{CoxTime.} With \code{loss="coxtime"}, the network represents a time-dependent +score \eqn{g(t, x)}. Internally, time is standardized before being concatenated with +standardized covariates. The scaling parameters are stored as +\code{coxtime_time_center} and \code{coxtime_time_scale} to ensure prediction uses the +same transformation. +} +\examples{ +\donttest{ +if (torch::torch_is_installed()) { + veteran <- survival::veteran + + # --- Cox model --- + fit_cox <- survdnn( + Surv(time, status) ~ age + karno + celltype, + data = veteran, + epochs = 50, + verbose = FALSE, + .seed = 1 + ) + lp <- predict(fit_cox, newdata = veteran, type = "lp") + S <- predict(fit_cox, newdata = veteran, type = "survival", times = c(30, 90, 180)) + + # --- AFT log-normal model --- + fit_aft <- survdnn( + Surv(time, status) ~ age + karno + celltype, + data = veteran, + loss = "aft", + epochs = 50, + verbose = FALSE, + .seed = 1 + ) + S_aft <- predict(fit_aft, newdata = veteran, type = "survival", times = c(30, 90, 180)) + + # --- CoxTime model --- + fit_ct <- survdnn( + Surv(time, status) ~ age + karno + celltype, + data = veteran, + loss = "coxtime", + epochs = 50, + verbose = FALSE, + .seed = 1 + ) + # By default, CoxTime survival predictions can use event times if times=NULL + S_ct <- predict(fit_ct, newdata = veteran, type = "survival") +} +} + }