diff --git a/R/evaluation.R b/R/evaluation.R
index 709ecd0..fc1b589 100644
--- a/R/evaluation.R
+++ b/R/evaluation.R
@@ -34,7 +34,7 @@ evaluate_survdnn <- function(model,
data <- if (is.null(newdata)) model$data else newdata
n_before <- nrow(data)
- # Build model frame first with explicit NA policy (so y aligns with predictions)
+ # build model frame first with explicit NA policy
mf <- model.frame(
model$formula,
data = data,
@@ -50,10 +50,10 @@ evaluate_survdnn <- function(model,
y <- model.response(mf)
if (!inherits(y, "Surv")) stop("The response must be a 'Surv' object.", call. = FALSE)
- # Predict on the filtered mf to keep row alignment
+ # predict on the filtered mf to keep row alignment
sp_matrix <- predict(model, newdata = mf, times = times, type = "survival")
- purrr::map_dfr(metrics, function(metric) {
+ purrr::map_dfr(metrics, function(metric) { ## TODO: replace with fmap() from the functionals package
if (metric == "brier" && length(times) > 1) {
tibble::tibble(
metric = "brier",
diff --git a/R/tune_survdnn.R b/R/tune_survdnn.R
index 46cbbaa..d0ed91d 100644
--- a/R/tune_survdnn.R
+++ b/R/tune_survdnn.R
@@ -83,7 +83,7 @@ tune_survdnn <- function(formula,
summary_tbl <- summarize_tune_survdnn(all_results, by_time = FALSE)
- ## Select best hyperparameters
+ ## select best hyperparameters
primary_metric <- metrics[1]
best_row_all <- all_results |>
@@ -99,7 +99,7 @@ tune_survdnn <- function(formula,
stop("No valid configuration found for primary metric: ", primary_metric, call. = FALSE)
}
- ## Refitting the best model
+ ## refitting the best model
if (refit) {
message("Refitting best model on full data...")
best_model <- survdnn(
diff --git a/R/zzz.R b/R/zzz.R
index 741f09c..460d1fc 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -9,14 +9,11 @@
"fold", "metric", "value", "id", "time", "surv", "group", "mean_surv",
"n", "se", "hidden", "lr", "activation", "epochs", "loss_name", ".loss_fn"
))
-
- # IMPORTANT: never load or probe torch here (because CRAN/Windows may segfault).
- # No torch checks, no tensor creation on load.
+ # never load or probe torch here (because CRAN/Windows may segfault)
}
## handles user-facing messaging
.onAttach <- function(libname, pkgname) {
- # Do NOT load torch or call torch::torch_is_installed() here.
# friendly hint that doesn't load the namespace:
torch_pkg_present <- nzchar(system.file(package = "torch"))
@@ -51,7 +48,7 @@ survdnn_set_seed <- function(.seed = NULL) {
-## Internal utility to choose a torch device for survdnn
+## internal utility to choose a torch device for survdnn
survdnn_get_device <- function(.device = c("auto", "cpu", "cuda")) {
.device <- match.arg(.device)
diff --git a/README.Rmd b/README.Rmd
index 14faaf1..16bdd2a 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -4,7 +4,7 @@ output: github_document
# survdnn
-> Deep Neural Networks for Survival Analysis using [torch](https://torch.mlverse.org/)
+> Deep Neural Networks for Survival Analysis using [R torch](https://torch.mlverse.org/)
[](LICENSE)
[](https://github.com/ielbadisy/survdnn/actions/workflows/R-CMD-check.yaml)
@@ -43,7 +43,7 @@ A methodological paper describing the design, implementation, and evaluation of
```{r, eval = FALSE}
# Install from CRAN
-install.packages("surdnn")
+install.packages("survdnn")
# Install from GitHub
diff --git a/README.html b/README.html
deleted file mode 100644
index 346a150..0000000
--- a/README.html
+++ /dev/null
@@ -1,906 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-survdnn
-
-
-Deep Neural Networks for Survival Analysis using torch
-
-
-
-About
-survdnn implements neural network-based models for
-right-censored survival analysis using the native torch
-backend in R. It supports multiple loss functions including Cox partial
-likelihood, L2-penalized Cox, Accelerated Failure Time (AFT) objectives,
-as well as time-dependent extension such as Cox-Time. The package
-provides a formula interface, supports model evaluation using
-time-dependent metrics (C-index, Brier score, IBS), cross-validation,
-and hyperparameter tuning.
-Review status
-A methodological paper describing the design, implementation, and
-evaluation of survdnn is currently under review at The
-R Journal.
-Main features
-
-Formula interface for Surv() ~ . models
-Modular neural architectures: configurable layers, activations,
-optimizers, and losses
-Built-in survival loss functions:
-
-"cox": Cox partial likelihood
-"cox_l2": penalized Cox
-"aft": Accelerated Failure Time
-"coxtime": deep time-dependent Cox
-
-Evaluation: C-index, Brier score, IBS
-Model selection with cv_survdnn() and
-tune_survdnn()
-Prediction of survival curves via predict() and
-plot()
-
-Installation
-# Install from CRAN
-install.packages("surdnn")
-
-
-# Install from GitHub
-install.packages("remotes")
-remotes::install_github("ielbadisy/survdnn")
-
-# Or clone and install locally
-git clone https://github.com/ielbadisy/survdnn.git
-setwd("survdnn")
-devtools::install()
-Quick example
-library(survdnn)
-library(survival, quietly = TRUE)
-library(ggplot2)
-
-veteran <- survival::veteran
-
-mod <- survdnn(
- Surv(time, status) ~ age + karno + celltype,
- data = veteran,
- hidden = c(32, 16),
- epochs = 300,
- loss = "cox",
- verbose = TRUE
- )
-## Epoch 50 - Loss: 3.983201
-##
-## Epoch 100 - Loss: 3.947356
-##
-## Epoch 150 - Loss: 3.934828
-##
-## Epoch 200 - Loss: 3.876191
-##
-## Epoch 250 - Loss: 3.813223
-##
-## Epoch 300 - Loss: 3.868888
-
-##
-## Formula:
-## Surv(time, status) ~ age + karno + celltype
-## <environment: 0x611459d0ec80>
-##
-## Model architecture:
-## Hidden layers: 32 : 16
-## Activation: relu
-## Dropout: 0.3
-## Final loss: 3.868888
-##
-## Training summary:
-## Epochs: 300
-## Learning rate: 1e-04
-## Loss function: cox
-## Optimizer: adam
-##
-## Data summary:
-## Observations: 137
-## Predictors: age, karno, celltypesmallcell, celltypeadeno, celltypelarge
-## Time range: [ 1, 999 ]
-## Event rate: 93.4%
-plot(mod, group_by = "celltype", times = 1:300)
-Loss Functions
-
-- Cox partial likelihood
-
-mod1 <- survdnn(
- Surv(time, status) ~ age + karno,
- data = veteran,
- loss = "cox",
- epochs = 300
- )
-## Epoch 50 - Loss: 3.986035
-##
-## Epoch 100 - Loss: 3.973183
-##
-## Epoch 150 - Loss: 3.944867
-##
-## Epoch 200 - Loss: 3.901533
-##
-## Epoch 250 - Loss: 3.849433
-##
-## Epoch 300 - Loss: 3.899746
-
-- Accelerated Failure Time
-
-mod2 <- survdnn(
- Surv(time, status) ~ age + karno,
- data = veteran,
- loss = "aft",
- epochs = 300
- )
-## Epoch 50 - Loss: 18.154217
-##
-## Epoch 100 - Loss: 17.844833
-##
-## Epoch 150 - Loss: 17.560537
-##
-## Epoch 200 - Loss: 17.134348
-##
-## Epoch 250 - Loss: 16.840366
-##
-## Epoch 300 - Loss: 16.344124
-
-mod3 <- survdnn(
- Surv(time, status) ~ age + karno,
- data = veteran,
- loss = "coxtime",
- epochs = 300
- )
-## Epoch 50 - Loss: 4.932558
-##
-## Epoch 100 - Loss: 4.864682
-##
-## Epoch 150 - Loss: 4.830169
-##
-## Epoch 200 - Loss: 4.784954
-##
-## Epoch 250 - Loss: 4.764827
-##
-## Epoch 300 - Loss: 4.731824
-Cross-validation
-cv_results <- cv_survdnn(
- Surv(time, status) ~ age + karno + celltype,
- data = veteran,
- times = c(600),
- metrics = c("cindex", "ibs"),
- folds = 3,
- hidden = c(16, 8),
- loss = "cox",
- epochs = 300
- )
-
-print(cv_results)
-Hyperparameter tuning
-grid <- list(
- hidden = list(c(16), c(32, 16)),
- lr = c(1e-3),
- activation = c("relu"),
- epochs = c(100, 300),
- loss = c("cox", "aft", "coxtime")
- )
-
-tune_res <- tune_survdnn(
- formula = Surv(time, status) ~ age + karno + celltype,
- data = veteran,
- times = c(90, 300),
- metrics = "cindex",
- param_grid = grid,
- folds = 3,
- refit = FALSE,
- return = "summary"
- )
-
-print(tune_res)
-Tuning and refitting the
-best Model
-tune_survdnn() can be used also to automatically refit
-the best-performing model on the full dataset. This behavior is
-controlled by the refit and return arguments.
-For example:
-best_model <- tune_survdnn(
- formula = Surv(time, status) ~ age + karno + celltype,
- data = veteran,
- times = c(90, 300),
- metrics = "cindex",
- param_grid = grid,
- folds = 3,
- refit = TRUE,
- return = "best_model"
- )
-In this mode, cross-validation is used to select the optimal
-hyperparameter configuration, after which the selected model is refitted
-on the full dataset. The function then returns a fitted object of class
-"survdnn".
-The resulting model can be used directly for prediction
-visualization, and evaluation:
-summary(best_model)
-
-plot(best_model, times = 1:300)
-
-predict(best_model, veteran, type = "risk", times = 180)
-This makes tune_survdnn() suitable for end-to-end
-workflows, combining model selection and final model fitting.
-Plot survival curves
-plot(mod1, group_by = "celltype", times = 1:300)
-
-plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
-
-Documentation
-help(package = "survdnn")
-?survdnn
-?tune_survdnn
-?cv_survdnn
-?plot.survdnn
-Testing
-# run all tests
-devtools::test()
-Note on reproducibility
-By default, {torch} initializes model weights and
-shuffles minibatches using random draws, so results may differ across
-runs. Unlike set.seed(), which only controls R’s random
-number generator, {torch} relies on its own RNG implemented
-in C++ (and CUDA when using GPUs).
-To ensure reproducibility, random seeds must therefore be set at the
-Torch level as well.
-survdnn provides built-in control of randomness to
-guarantee reproducible results across runs. The main fitting function,
-survdnn(), exposes a dedicated .seed
-argument:
-mod <- survdnn(
- Surv(time, status) ~ age + karno + celltype,
- data = veteran,
- epochs = 300,
- .seed = 123
-)
-When .seed is provided, survdnn()
-internally synchronizes both R and Torch random number generators via
-survdnn_set_seed(), ensuring reproducible:
-
-weight initialization
-dropout behavior
-minibatch ordering
-loss trajectories
-
-If .seed = NULL (the default), randomness is left
-uncontrolled and results may vary between runs.
-For full reproducibility in cross-validation or hyperparameter
-tuning, the same .seed mechanism is propagated internally
-by cv_survdnn() and tune_survdnn(), ensuring
-consistent data splits, model initialization, and optimization paths
-across repetitions.
-CPU and core usage
-survdnn relies on the {torch} backend for
-numerical computation. The number of CPU cores (threads) used during
-training, prediction, and evaluation is controlled globally by
-Torch.
-By default, Torch automatically configures its CPU thread pools based
-on the available system resources, unless explicitly overridden by the
-user using:
-torch::torch_set_num_threads(4)
-This setting affects:
-
-GPU acceleration can be enabled by setting
-.device = "cuda" when calling survdnn()
-(cv_survdnn() and tune_survdnn() too).
-Availability
-The survdnn R package is available on CRAN or github
-Contributions
-Contributions, issues, and feature requests are welcome!
-Open an issue or submit a
-pull request.
-License
-MIT License © 2025 Imad EL BADISY
-
-
-
diff --git a/README.md b/README.md
index 53c668e..1d7dba4 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
# survdnn
-> Deep Neural Networks for Survival Analysis using
-> [torch](https://torch.mlverse.org/)
+> Deep Neural Networks for Survival Analysis using [R
+> torch](https://torch.mlverse.org/)
[](LICENSE)
@@ -48,7 +48,7 @@ evaluation of `survdnn` is currently under review at *The R Journal*.
``` r
# Install from CRAN
-install.packages("surdnn")
+install.packages("survdnn")
# Install from GitHub
@@ -80,17 +80,17 @@ mod <- survdnn(
)
```
- ## Epoch 50 - Loss: 3.983201
+ ## Epoch 50 - Loss: 3.967377
##
- ## Epoch 100 - Loss: 3.947356
+ ## Epoch 100 - Loss: 3.863189
##
- ## Epoch 150 - Loss: 3.934828
+ ## Epoch 150 - Loss: 3.879065
##
- ## Epoch 200 - Loss: 3.876191
+ ## Epoch 200 - Loss: 3.814478
##
- ## Epoch 250 - Loss: 3.813223
+ ## Epoch 250 - Loss: 3.756944
##
- ## Epoch 300 - Loss: 3.868888
+ ## Epoch 300 - Loss: 3.823366
``` r
summary(mod)
@@ -99,13 +99,13 @@ summary(mod)
##
## Formula:
## Surv(time, status) ~ age + karno + celltype
- ##
+ ##
##
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
- ## Final loss: 3.868888
+ ## Final loss: 3.823366
##
## Training summary:
## Epochs: 300
@@ -136,17 +136,17 @@ mod1 <- survdnn(
)
```
- ## Epoch 50 - Loss: 3.986035
+ ## Epoch 50 - Loss: 3.988259
##
- ## Epoch 100 - Loss: 3.973183
+ ## Epoch 100 - Loss: 3.930287
##
- ## Epoch 150 - Loss: 3.944867
+ ## Epoch 150 - Loss: 3.913787
##
- ## Epoch 200 - Loss: 3.901533
+ ## Epoch 200 - Loss: 3.896528
##
- ## Epoch 250 - Loss: 3.849433
+ ## Epoch 250 - Loss: 3.819792
##
- ## Epoch 300 - Loss: 3.899746
+ ## Epoch 300 - Loss: 3.893889
- Accelerated Failure Time
@@ -159,17 +159,17 @@ mod2 <- survdnn(
)
```
- ## Epoch 50 - Loss: 18.154217
+ ## Epoch 50 - Loss: 16.911470
##
- ## Epoch 100 - Loss: 17.844833
+ ## Epoch 100 - Loss: 16.589067
##
- ## Epoch 150 - Loss: 17.560537
+ ## Epoch 150 - Loss: 16.226612
##
- ## Epoch 200 - Loss: 17.134348
+ ## Epoch 200 - Loss: 15.959708
##
- ## Epoch 250 - Loss: 16.840366
+ ## Epoch 250 - Loss: 15.182121
##
- ## Epoch 300 - Loss: 16.344124
+ ## Epoch 300 - Loss: 15.049762
- Coxtime
@@ -182,17 +182,17 @@ mod3 <- survdnn(
)
```
- ## Epoch 50 - Loss: 4.932558
+ ## Epoch 50 - Loss: 4.888907
##
- ## Epoch 100 - Loss: 4.864682
+ ## Epoch 100 - Loss: 4.846722
##
- ## Epoch 150 - Loss: 4.830169
+ ## Epoch 150 - Loss: 4.838490
##
- ## Epoch 200 - Loss: 4.784954
+ ## Epoch 200 - Loss: 4.816662
##
- ## Epoch 250 - Loss: 4.764827
+ ## Epoch 250 - Loss: 4.780379
##
- ## Epoch 300 - Loss: 4.731824
+ ## Epoch 300 - Loss: 4.756117
## Cross-validation
diff --git a/README_files/figure-gfm/unnamed-chunk-11-1.png b/README_files/figure-gfm/unnamed-chunk-11-1.png
index b79fc28..1ae8cf2 100644
Binary files a/README_files/figure-gfm/unnamed-chunk-11-1.png and b/README_files/figure-gfm/unnamed-chunk-11-1.png differ
diff --git a/README_files/figure-gfm/unnamed-chunk-12-1.png b/README_files/figure-gfm/unnamed-chunk-12-1.png
index 744d754..0954982 100644
Binary files a/README_files/figure-gfm/unnamed-chunk-12-1.png and b/README_files/figure-gfm/unnamed-chunk-12-1.png differ
diff --git a/man/build_dnn.Rd b/man/build_dnn.Rd
index e6af471..b6fc1a6 100644
--- a/man/build_dnn.Rd
+++ b/man/build_dnn.Rd
@@ -19,7 +19,8 @@ build_dnn(
\item{hidden}{Integer vector. Sizes of the hidden layers (e.g., c(32, 16)).}
\item{activation}{Character. Name of the activation function to use in each layer.
-Supported options: `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`, `"elu"`, `"softplus"`.}
+Supported options: `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`,
+`"elu"`, `"softplus"`.}
\item{output_dim}{Integer. Output layer dimension (default = 1).}
diff --git a/man/survdnn.Rd b/man/survdnn.Rd
index 0e3385f..f62f471 100644
--- a/man/survdnn.Rd
+++ b/man/survdnn.Rd
@@ -24,76 +24,147 @@ survdnn(
)
}
\arguments{
-\item{formula}{A survival formula of the form `Surv(time, status) ~ predictors`.}
+\item{formula}{A survival formula of the form \code{Surv(time, status) ~ predictors}.}
\item{data}{A data frame containing the variables in the model.}
-\item{hidden}{Integer vector. Sizes of the hidden layers (default: c(32, 16)).}
+\item{hidden}{Integer vector giving hidden layer widths (e.g., \code{c(32L, 16L)}).}
-\item{activation}{Character string specifying the activation function to use in each layer.
-Supported options: `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`, `"elu"`, `"softplus"`.}
+\item{activation}{Activation function used in each hidden layer. One of
+\code{"relu"}, \code{"leaky_relu"}, \code{"tanh"}, \code{"sigmoid"},
+\code{"gelu"}, \code{"elu"}, \code{"softplus"}.}
-\item{lr}{Learning rate for the optimizer (default: `1e-4`).}
+\item{lr}{Learning rate passed to the optimizer (default \code{1e-4}).}
-\item{epochs}{Number of training epochs (default: 300).}
+\item{epochs}{Number of training epochs (default \code{300L}).}
-\item{loss}{Character name of the loss function to use. One of `"cox"`, `"cox_l2"`, `"aft"`, or `"coxtime"`.}
+\item{loss}{Loss function to optimize. One of \code{"cox"}, \code{"cox_l2"},
+\code{"aft"}, \code{"coxtime"}.}
-\item{optimizer}{Character string specifying the optimizer to use. One of
-`"adam"`, `"adamw"`, `"sgd"`, `"rmsprop"`, or `"adagrad"`. Defaults to `"adam"`.}
+\item{optimizer}{Optimizer name. One of \code{"adam"}, \code{"adamw"},
+\code{"sgd"}, \code{"rmsprop"}, \code{"adagrad"}.}
-\item{optim_args}{Optional named list of additional arguments passed to the
-underlying torch optimizer (e.g., `list(weight_decay = 1e-4, momentum = 0.9)`).}
+\item{optim_args}{Optional named list of extra arguments passed to the chosen
+torch optimizer (e.g., \code{list(weight_decay = 1e-4, momentum = 0.9)}).}
-\item{verbose}{Logical; whether to print loss progress every 50 epochs (default: TRUE).}
+\item{verbose}{Logical; whether to print training progress every 50 epochs.}
-\item{dropout}{Numeric between 0 and 1. Dropout rate applied after each
-hidden layer (default = 0.3). Set to 0 to disable dropout.}
+\item{dropout}{Dropout rate applied after each hidden layer (set \code{0} to disable).}
-\item{batch_norm}{Logical; whether to add batch normalization after each
-hidden linear layer (default = TRUE).}
+\item{batch_norm}{Logical; whether to add batch normalization after each hidden linear layer.}
-\item{callbacks}{Optional list of callback functions. Each callback should have
-signature `function(epoch, current)` and return TRUE if training should stop,
-FALSE otherwise. Used, for example, with [callback_early_stopping()].}
+\item{callbacks}{Optional callback(s) for early stopping or monitoring.
+May be \code{NULL}, a single function, or a list of functions. Each callback must have
+signature \code{function(epoch, current_loss)} and return \code{TRUE} to stop training,
+\code{FALSE} otherwise.}
-\item{.seed}{Optional integer. If provided, sets both R and torch random seeds for reproducible
-weight initialization, shuffling, and dropout.}
+\item{.seed}{Optional integer seed controlling both R and torch RNGs (weight init,
+shuffling, dropout) for reproducibility.}
-\item{.device}{Character string indicating the computation device.
-One of `"auto"`, `"cpu"`, or `"cuda"`. `"auto"` uses CUDA if available,
-otherwise falls back to CPU.}
+\item{.device}{Computation device. One of \code{"auto"}, \code{"cpu"}, \code{"cuda"}.
+\code{"auto"} selects CUDA when available.}
-\item{na_action}{Character. How to handle missing values in the model variables:
-`"omit"` drops incomplete rows (and reports how many were removed when `verbose=TRUE`);
-`"fail"` stops with an error if any missing values are present.}
+\item{na_action}{Missing-data handling. \code{"omit"} drops incomplete rows (and reports
+how many were removed when \code{verbose = TRUE}); \code{"fail"} errors if any missing
+values are present in model variables.}
}
\value{
-An object of class `"survdnn"` containing:
+An object of class \code{"survdnn"} with components:
\describe{
- \item{model}{Trained `nn_module` object.}
- \item{formula}{Original survival formula.}
- \item{data}{Training data used for fitting.}
- \item{xnames}{Predictor variable names.}
- \item{x_center}{Column means of predictors.}
- \item{x_scale}{Column standard deviations of predictors.}
- \item{loss_history}{Vector of loss values per epoch.}
- \item{final_loss}{Final training loss.}
- \item{loss}{Loss function name used ("cox", "aft", etc.).}
- \item{activation}{Activation function used.}
+ \item{model}{Trained torch \code{nn_module} (MLP).}
+ \item{formula}{Model formula used for fitting.}
+ \item{data}{Training data used for fitting (original \code{data} argument).}
+ \item{xnames}{Predictor column names used by the model matrix.}
+ \item{x_center}{Numeric vector of predictor means used for scaling.}
+ \item{x_scale}{Numeric vector of predictor standard deviations used for scaling.}
+ \item{loss_history}{Numeric vector of loss values per epoch (possibly truncated by early stopping).}
+ \item{final_loss}{Final loss value (last element of \code{loss_history}).}
+ \item{loss}{Loss name used for training.}
+ \item{activation}{Activation function name.}
\item{hidden}{Hidden layer sizes.}
\item{lr}{Learning rate.}
- \item{epochs}{Number of training epochs.}
- \item{optimizer}{Optimizer name used.}
+ \item{epochs}{Number of requested epochs.}
+ \item{optimizer}{Optimizer name.}
\item{optim_args}{List of optimizer arguments used.}
- \item{device}{Torch device used for training (`torch_device`).}
- \item{aft_log_sigma}{Learned global log(sigma) for `loss="aft"`; `NA_real_` otherwise.}
- \item{aft_loc}{AFT log-time location offset used for centering when `loss="aft"`; `NA_real_` otherwise.}
- \item{coxtime_time_center}{Mean used to scale time for CoxTime; `NA_real_` otherwise.}
- \item{coxtime_time_scale}{SD used to scale time for CoxTime; `NA_real_` otherwise.}
+ \item{device}{Torch device used for fitting.}
+ \item{dropout}{Dropout rate used.}
+ \item{batch_norm}{Whether batch normalization was used.}
+ \item{na_action}{Missing-data strategy used.}
+ \item{aft_log_sigma}{Learned global \code{log(sigma)} for AFT; \code{NA_real_} otherwise.}
+ \item{aft_loc}{Log-time centering offset used for AFT; \code{NA_real_} otherwise.}
+ \item{coxtime_time_center}{Time centering used for CoxTime; \code{NA_real_} otherwise.}
+ \item{coxtime_time_scale}{Time scaling used for CoxTime; \code{NA_real_} otherwise.}
}
}
\description{
-Trains a deep neural network (DNN) to model right-censored survival data
-using one of the predefined loss functions: Cox, AFT, or Coxtime.
+Fits a deep neural network (MLP) for right-censored time-to-event data using
+one of the supported losses: Cox partial likelihood, L2-penalized Cox,
+log-normal AFT (censored negative log-likelihood), or CoxTime (time-dependent
+relative risk model).
+}
+\details{
+The function:
+\itemize{
+ \item builds an MLP via [build_dnn()],
+ \item preprocesses predictors using centering/scaling (stored in the model),
+ \item optionally applies log-time centering for AFT (stored as \code{aft_loc}),
+ \item scales time for CoxTime to stabilize optimization (stored as \code{coxtime_time_center}/\code{coxtime_time_scale}),
+ \item trains the network with a torch optimizer and optional callbacks.
+}
+
+
+\strong{AFT model.} With \code{loss="aft"}, the model is a log-normal AFT model:
+\deqn{\log(T) = \text{aft\_loc} + \mu_{\text{resid}}(x) + \sigma \varepsilon, \quad \varepsilon \sim \mathcal{N}(0,1).}
+For numerical stability, training uses centered log-times
+\code{log(time) - aft_loc}. The learned network output corresponds to
+\code{mu_resid(x)}. The fitted object stores \code{aft_loc} and the learned global
+\code{aft_log_sigma}.
+
+\strong{CoxTime.} With \code{loss="coxtime"}, the network represents a time-dependent
+score \eqn{g(t, x)}. Internally, time is standardized before being concatenated with
+standardized covariates. The scaling parameters are stored as
+\code{coxtime_time_center} and \code{coxtime_time_scale} to ensure prediction uses the
+same transformation.
+}
+\examples{
+\donttest{
+if (torch::torch_is_installed()) {
+ veteran <- survival::veteran
+
+ # --- Cox model ---
+ fit_cox <- survdnn(
+ Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ epochs = 50,
+ verbose = FALSE,
+ .seed = 1
+ )
+ lp <- predict(fit_cox, newdata = veteran, type = "lp")
+ S <- predict(fit_cox, newdata = veteran, type = "survival", times = c(30, 90, 180))
+
+ # --- AFT log-normal model ---
+ fit_aft <- survdnn(
+ Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ loss = "aft",
+ epochs = 50,
+ verbose = FALSE,
+ .seed = 1
+ )
+ S_aft <- predict(fit_aft, newdata = veteran, type = "survival", times = c(30, 90, 180))
+
+ # --- CoxTime model ---
+ fit_ct <- survdnn(
+ Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ loss = "coxtime",
+ epochs = 50,
+ verbose = FALSE,
+ .seed = 1
+ )
+ # By default, CoxTime survival predictions can use event times if times=NULL
+ S_ct <- predict(fit_ct, newdata = veteran, type = "survival")
+}
+}
+
}