From 67069597f19205caaf69138f8dad657232db9fde Mon Sep 17 00:00:00 2001
From: Imad EL BADISY
Date: Mon, 5 Jan 2026 16:19:08 +0100
Subject: [PATCH 1/5] tmp/loss-fixes (#18)
* update version to 0.7.0
* update the README
* update news for CRAN submission
* update cran-comments
* update the readme
* fix AFT and CoxTime losses, predictions, and tests
---
DESCRIPTION | 4 +-
NEWS.md | 69 ++-
R/losses.R | 256 +++++++--
R/predict.survdnn.R | 256 +++++----
R/survdnn.R | 501 ++++++++++--------
README.Rmd | 171 ++++--
README.html | 298 +++++++----
README.md | 267 +++++++---
.../figure-gfm/unnamed-chunk-11-1.png | Bin 37312 -> 157083 bytes
.../figure-gfm/unnamed-chunk-12-1.png | Bin 0 -> 38404 bytes
cran-comments.md | 26 +-
tests/testthat/test-losses.R | 125 +++--
tests/testthat/test-predict.survdnn.R | 4 +-
tests/testthat/test-survdnn.R | 17 +-
14 files changed, 1288 insertions(+), 706 deletions(-)
create mode 100644 README_files/figure-gfm/unnamed-chunk-12-1.png
diff --git a/DESCRIPTION b/DESCRIPTION
index 60f19fb..68f79cb 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
Package: survdnn
-Title: Deep Neural Networks for Survival Analysis Using 'torch'
-Version: 0.6.3
+Title: Deep Neural Networks for Survival Analysis using 'torch'
+Version: 0.7.0
Authors@R:
person(given = "Imad",
family = "EL BADISY",
diff --git a/NEWS.md b/NEWS.md
index dc16f08..c85cfd0 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,19 +1,48 @@
-# survdnn 0.6.3
+# survdnn
+
+## survdnn 0.7.0
+
+### Major changes
+
+* Added full support for **training control mechanisms**, including early stopping callbacks and complete loss tracking across epochs.
+
+* Introduced `plot_loss()` to visualize training loss trajectories and diagnose convergence or instability.
+
+* Centralized **reproducibility control** via the `.seed` argument in `survdnn()`, synchronizing both R and Torch random number generators.
+
+* Expanded optimizer support to include **Adam, AdamW, SGD, RMSprop, and Adagrad**, with customizable optimizer arguments.
+
+* Enhanced **prediction methods** to robustly support linear predictors, survival probabilities, and cumulative risk across all supported loss functions.
+
+* Added explicit and user-controllable **missing-data handling** (`na_action = "omit"` or `"fail"`), with informative messages.
+
+### Minor changes
+
+* Improved handling of formulas using `Surv(...) ~ .` in prediction and evaluation.
+
+* Improved printing and summary methods for fitted `survdnn` objects.
+
+* Expanded unit test coverage, including optimizers, plotting utilities, and missing-data edge cases.
+
+### Bug fixes
+
+* Fixed inconsistencies in prediction and evaluation when formulas used `.` expansion.
+
+---
## survdnn 0.6.2
### Maintenance release (CRAN compliance)
- **Removed automatic `torch::install_torch()` on load:**
- The package no longer downloads or installs Torch libraries automatically when loaded.
- The `.onLoad()` function now performs only a silent availability check, and `.onAttach()`
- displays an informative message instructing users to manually run
- `torch::install_torch()` when necessary.
- This update ensures full compliance with CRAN policies that forbid modification of user environments
- or network activity during package load.
-- Updated startup message for clearer user guidance.
-- Internal documentation and version bump for CRAN resubmission.
+ The package no longer downloads or installs Torch libraries automatically when loaded. The `.onLoad()` function now performs only a silent availability check, and `.onAttach()` displays an informative message instructing users to manually run `torch::install_torch()` when necessary.
+
+ This update ensures full compliance with CRAN policies that forbid modification of user environments or network activity during package load.
+
+- Updated startup messages for clearer user guidance.
+
+- Internal documentation updates and version bump for CRAN resubmission.
---
@@ -21,31 +50,29 @@
### Infrastructure and testing improvements
-- Added conditional test skipping: tests and examples now use
- `skip_if_not(torch_is_installed())` and `skip_on_cran()` to avoid failures
- on systems where Torch is not available.
- (Thanks to @dfalbel for the PR.)
+- Added conditional test skipping: tests and examples now use `skip_if_not(torch_is_installed())` and `skip_on_cran()` to avoid failures on systems where Torch is not available (thanks to @dfalbel for the [PR](https://github.com/ielbadisy/survdnn/pull/5)).
- Regenerated documentation (`RoxygenNote: 7.3.3`) and updated man pages.
-- Minor internal consistency fixes and CI checks update.
+
+- Minor internal consistency fixes and CI check updates.
---
## survdnn 0.6.0
-First public release of `survdnn` — Deep Neural Networks for Survival Analysis in R using `torch`.
+First public release of `survdnn`.
### Features
- `survdnn()`: Fit deep learning survival models using a formula interface.
-- Supported loss functions:
- - Cox partial likelihood (`"cox"`)
- - L2-penalized Cox (`"cox_l2"`)
- - Time-dependent Cox (`"coxtime"`)
+- Supported loss functions:
+ - Cox partial likelihood (`"cox"`)
+ - L2-penalized Cox (`"cox_l2"`)
+ - Time-dependent Cox (`"coxtime"`)
- Accelerated Failure Time (`"aft"`)
-- Cross-validation with `cv_survdnn()`.
+- Cross-validation via `cv_survdnn()`.
- Hyperparameter tuning with `tune_survdnn()`.
- Survival probability prediction and curve plotting.
- Evaluation metrics: Concordance index (C-index), Brier score, and Integrated Brier Score (IBS).
-CRAN submission prepared. Includes README, documentation, and automated tests.
+CRAN submission prepared, including README, documentation, and automated tests.
diff --git a/R/losses.R b/R/losses.R
index 23c8e49..c892a52 100644
--- a/R/losses.R
+++ b/R/losses.R
@@ -1,16 +1,24 @@
#' Loss Functions for survdnn Models
#'
-#' These functions define various loss functions used internally by `survdnn()` for training deep neural networks on right-censored survival data.
+#' These functions define various loss functions used internally by `survdnn()`
+#' for training deep neural networks on right-censored survival data.
#'
#' @section Supported Losses:
#' - **Cox partial likelihood loss** (`cox_loss`): Negative partial log-likelihood used in proportional hazards modeling.
#' - **L2-penalized Cox loss** (`cox_l2_loss`): Adds L2 regularization to the Cox loss.
-#' - **Accelerated Failure Time (AFT) loss** (`aft_loss`): Mean squared error between predicted and log-transformed event times, applied to uncensored observations only.
-#' - **CoxTime loss** (`coxtime_loss`): Implements the partial likelihood loss from Kvamme & Borgan (2019), used in Cox-Time models.
+#' - **Accelerated Failure Time (AFT) loss** (`aft_loss`): Log-normal AFT **censored negative log-likelihood**
+#' (uses both events and censored observations).
+#' - **CoxTime loss** (`coxtime_loss`): Placeholder (see details). A correct CoxTime loss requires access to the network and the full input tensor.
#'
#' @param pred A tensor of predicted values (typically linear predictors or log-times).
#' @param true A tensor with two columns: observed time and status (1 = event, 0 = censored).
#' @param lambda Regularization parameter for `cox_l2_loss` (default: `1e-4`).
+#' @param sigma Positive numeric scale parameter for the log-normal AFT model (default: `1`).
+#' In `survdnn()`, a learnable global scale can be used via `survdnn__aft_lognormal_nll_factory()`.
+#' @param aft_loc Numeric scalar location offset for the AFT model on the log-time scale.
+#' When non-zero, the model is trained on centered log-times `log(time) - aft_loc` for better numerical stability.
+#' Prediction should add this offset back: `mu = mu_resid + aft_loc`.
+#' @param eps Small constant for numerical stability (default: `1e-12`).
#'
#' @return A scalar `torch_tensor` representing the loss value.
#' @name survdnn_losses
@@ -20,90 +28,240 @@
NULL
+# -------------------------------------------------------------------------
+# Internal utilities
+# -------------------------------------------------------------------------
+
+#' @keywords internal
+survdnn__zeros_like_scalar <- function(x) {
+ torch::torch_zeros_like(x$view(c(1)))[1]
+}
+
+#' @keywords internal
+survdnn__count_true <- function(mask) {
+ as.integer(mask$sum()$item())
+}
+
+#' @keywords internal
+survdnn__log_surv_std_normal <- function(z, eps = 1e-12) {
+ sqrt2 <- torch::torch_sqrt(torch::torch_tensor(2, dtype = z$dtype, device = z$device))
+ u <- z / sqrt2
+ S <- torch::torch_clamp(0.5 * (1 - torch::torch_erf(u)), min = eps)
+ torch::torch_log(S)
+}
+
+
+# -------------------------------------------------------------------------
+# Cox loss (keeps the existing sign convention: lp = -net(x))
+# -------------------------------------------------------------------------
+
#' @rdname survdnn_losses
#' @export
cox_loss <- function(pred, true) {
- time <- true[, 1]
+ time <- true[, 1]
status <- true[, 2]
- idx <- torch_argsort(time, descending = TRUE)
- time <- time[idx]
+ idx <- torch::torch_argsort(time, descending = TRUE)
status <- status[idx]
- pred <- -pred[idx, 1] # negate for log-partial likelihood
- log_cumsum_exp <- torch_logcumsumexp(pred, dim = 1)
+ lp <- -pred[idx, 1]
+
event_mask <- (status == 1)
+ if (survdnn__count_true(event_mask) == 0) {
+ return(survdnn__zeros_like_scalar(lp[1]))
+ }
- loss <- -torch_mean(pred[event_mask] - log_cumsum_exp[event_mask])
- loss
+ log_cumsum_exp <- torch::torch_logcumsumexp(lp, dim = 1)
+ -torch::torch_mean(lp[event_mask] - log_cumsum_exp[event_mask])
}
-
#' @rdname survdnn_losses
#' @export
cox_l2_loss <- function(pred, true, lambda = 1e-4) {
base_loss <- cox_loss(pred, true)
- l2_penalty <- lambda * torch_mean(pred^2)
+ lp <- -pred[, 1]
+ l2_penalty <- lambda * torch::torch_mean(lp^2)
base_loss + l2_penalty
}
+# -------------------------------------------------------------------------
+# AFT loss (Option B): log-normal AFT censored negative log-likelihood
+# -------------------------------------------------------------------------
+
#' @rdname survdnn_losses
#' @export
-aft_loss <- function(pred, true) {
- time <- true[, 1]
+aft_loss <- function(pred, true, sigma = 1, aft_loc = 0, eps = 1e-12) {
+
+ time <- true[, 1]
status <- true[, 2]
- log_time <- torch_log(time)
- event_mask <- (status == 1)
- n_events <- as.numeric(torch_sum(event_mask))
+ t <- torch::torch_clamp(time, min = eps)
+ lt <- torch::torch_log(t)
- if (n_events == 0) {
- return(torch_zeros_like(pred[1, 1])) ## this ensure the returned loss has the same device as pred & has the same dtype as pred (CPU/CUDA/MPS)
- }
+ mu_resid <- pred[, 1]
+
+ sigma_t <- torch::torch_tensor(
+ as.numeric(sigma),
+ dtype = mu_resid$dtype,
+ device = mu_resid$device
+ )
+ sigma_t <- torch::torch_clamp(sigma_t, min = eps)
+ log_sigma <- torch::torch_log(sigma_t)
+
+ aft_loc_t <- torch::torch_tensor(
+ as.numeric(aft_loc),
+ dtype = mu_resid$dtype,
+ device = mu_resid$device
+ )
+
+ lt_c <- lt - aft_loc_t
+ z <- (lt_c - mu_resid) / sigma_t
+
+ logS <- survdnn__log_surv_std_normal(z, eps = eps)
- pred_event <- pred[event_mask, 1]
- log_time_event <- log_time[event_mask]
+ nll_event <- lt + log_sigma + 0.5 * z^2
+ nll_cens <- -logS
- torch_mean((pred_event - log_time_event)^2)
+ nll <- torch::torch_where(status == 1, nll_event, nll_cens)
+ torch::torch_mean(nll)
}
+# -------------------------------------------------------------------------
+# CoxTime loss — cannot be correct with (pred, true) only
+# -------------------------------------------------------------------------
+
#' @rdname survdnn_losses
#' @export
coxtime_loss <- function(pred, true) {
+ stop(
+ "coxtime_loss(pred, true) is not identifiable from `pred` alone.\n",
+ "Cox-Time requires evaluating g(t_i, x_j) for all subjects j at each event time t_i.\n",
+ "Use the internal factory `survdnn__coxtime_loss_factory()` from survdnn() where `net` and the full input tensor are available.",
+ call. = FALSE
+ )
+}
- # `pred` is a tensor of shape [n, 1]: g(t_i, x_i)
- # `true` is a tensor with columns: time and status
- time <- true[, 1]
- status <- true[, 2]
- n <- time$size()[[1]]
+# -------------------------------------------------------------------------
+# Internal: Correct CoxTime loss factory
+#
+# IMPORTANT FIX:
+# - use `true[,1]` (RAW time) for sorting + risk sets
+# - use `x_tensor[,1]` (TIME AS FED TO NET; possibly scaled) when calling net
+# -------------------------------------------------------------------------
- # sorting by time descending
- idx <- torch_argsort(time, descending = TRUE)
- time <- time[idx]
- status <- status[idx]
- pred <- pred[idx, 1] # ensure shape [n]
+#' @keywords internal
+survdnn__coxtime_loss_factory <- function(net) {
- # compute risk set matrix: R_ij = 1 if time_j >= time_i
- time_i <- time$view(c(n, 1)) # [n, 1]
- time_j <- time$view(c(1, n)) # [1, n]
- risk_matrix <- (time_j >= time_i)$to(dtype = torch_float()) # [n, n]
+ force(net)
- # compute difference: g(t_i, x_j) - g(t_i, x_i)
- pred_i <- pred$view(c(n, 1)) # [n, 1]
- pred_j <- pred$view(c(1, n)) # [1, n]
- diff <- pred_j - pred_i # [n, n]
+ function(x_tensor, true, eps = 1e-12) {
- # mask for events only
- event_mask <- (status == 1)
+ time_raw <- true[, 1]
+ status <- true[, 2]
+ n <- time_raw$size()[[1]]
+
+ d <- x_tensor$size()[[2]]
+ if (d < 2) stop("CoxTime expects x_tensor with at least 2 columns: (time, x).", call. = FALSE)
- # compute log sum exp over risk set
- log_sum_exp <- torch_logsumexp(diff * risk_matrix, dim = 2) # [n]
+ time_inp <- x_tensor[, 1] # time as used by the net (can be raw or scaled)
+ x_cov <- x_tensor[, 2:d, drop = FALSE]
+
+ ## sort by RAW time (risk sets depend on raw ordering)
+ idx <- torch::torch_argsort(time_raw, descending = TRUE)
+
+ time_raw <- time_raw[idx]
+ time_inp <- time_inp[idx]
+ status <- status[idx]
+ x_cov <- x_cov[idx, , drop = FALSE]
+
+ event_mask <- (status == 1)
+ ne <- as.integer(event_mask$sum()$item())
+ if (ne == 0) return(torch::torch_zeros_like(time_raw[1]))
+
+ ## event times
+ t_event_raw <- time_raw[event_mask] # for risk sets
+ t_event_inp <- time_inp[event_mask] # for net input
+ x_event <- x_cov[event_mask, , drop = FALSE]
+
+ ## numerator: g(t_i, x_i) for events
+ inp_num <- torch::torch_cat(list(t_event_inp$unsqueeze(2), x_event), dim = 2)
+ g_num <- net(inp_num)[, 1]
+
+ ## denominator: for each event time t_i, evaluate g(t_i, x_j) for all j
+ p <- x_cov$size()[[2]]
+
+ x_rep <- x_cov$unsqueeze(2)$expand(c(n, ne, p))$permute(c(2, 1, 3))
+ t_rep <- t_event_inp$view(c(ne, 1, 1))$expand(c(ne, n, 1)) # time for net input
+ inp_den <- torch::torch_cat(list(t_rep, x_rep), dim = 3)
+
+ inp_den2 <- inp_den$reshape(c(ne * n, d))
+ g_den2 <- net(inp_den2)[, 1]
+ g_den <- g_den2$reshape(c(ne, n))
+
+ ## risk sets computed on RAW time
+ time_j <- time_raw$view(c(1, n))
+ t_i <- t_event_raw$view(c(ne, 1))
+ risk <- (time_j >= t_i)
+
+ neg_inf <- torch::torch_tensor(-Inf, dtype = g_den$dtype, device = g_den$device)
+ g_masked <- torch::torch_where(risk, g_den, neg_inf)
+
+ log_denom <- torch::torch_logsumexp(g_masked, dim = 2)
+ -torch::torch_mean(g_num - log_denom)
+ }
+}
+
+
+# -------------------------------------------------------------------------
+# Internal: AFT log-normal censored NLL factory (learnable global log(sigma))
+# with optional centering by aft_loc.
+# -------------------------------------------------------------------------
+
+#' @keywords internal
+survdnn__aft_lognormal_nll_factory <- function(device, aft_loc = 0) {
+
+ log_sigma <- torch::torch_tensor(
+ 0,
+ dtype = torch::torch_float(),
+ device = device,
+ requires_grad = TRUE
+ )
+
+ aft_loc_t <- torch::torch_tensor(
+ as.numeric(aft_loc),
+ dtype = torch::torch_float(),
+ device = device
+ )
+
+ loss_fn <- function(net, x, y, eps = 1e-12) {
+
+ time <- y[, 1]
+ status <- y[, 2]
+
+ mu_resid <- net(x)[, 1]
+
+ t <- torch::torch_clamp(time, min = eps)
+ lt <- torch::torch_log(t)
+
+ lt_c <- lt - aft_loc_t
+
+ sigma <- torch::torch_clamp(torch::torch_exp(log_sigma), min = eps)
+ z <- (lt_c - mu_resid) / sigma
+
+ logS <- survdnn__log_surv_std_normal(z, eps = eps)
+
+ nll_event <- lt + log_sigma + 0.5 * z^2
+ nll_cens <- -logS
+
+ nll <- torch::torch_where(status == 1, nll_event, nll_cens)
+ torch::torch_mean(nll)
+ }
- # final partial likelihood loss: mean over events only
- loss_terms <- log_sum_exp[event_mask]
- loss <- torch_mean(loss_terms)
- return(loss)
+ list(
+ loss_fn = loss_fn,
+ extra_params = list(log_sigma = log_sigma)
+ )
}
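Reviewer note: the censored log-normal NLL implemented in `aft_loss()` above can be sanity-checked outside torch. A minimal pure-Python sketch of the same per-observation terms (illustrative only; these function names are not part of the package):

```python
import math

def std_normal_log_surv(z, eps=1e-12):
    # log S(z) for a standard normal, clamped for numerical stability
    s = max(0.5 * (1.0 - math.erf(z / math.sqrt(2.0))), eps)
    return math.log(s)

def aft_lognormal_nll(mu, sigma, times, status, eps=1e-12):
    # Censored log-normal AFT negative log-likelihood:
    # events contribute -log f(t) = log t + log sigma + z^2/2 (up to a constant),
    # censored observations contribute -log S(t), with z = (log t - mu) / sigma.
    total = 0.0
    for t, mu_i, d in zip(times, mu, status):
        lt = math.log(max(t, eps))
        z = (lt - mu_i) / sigma
        if d == 1:
            total += lt + math.log(sigma) + 0.5 * z * z
        else:
            total += -std_normal_log_surv(z, eps)
    return total / len(times)
```

With `sigma = 1` and `mu = log(t)` for an event, `z = 0` and the contribution reduces to `log t`, matching the `nll_event` term in the torch code.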
diff --git a/R/predict.survdnn.R b/R/predict.survdnn.R
index 518536a..da79d7a 100644
--- a/R/predict.survdnn.R
+++ b/R/predict.survdnn.R
@@ -6,12 +6,16 @@
#' @param object An object of class `"survdnn"` returned by [survdnn()].
#' @param newdata A data frame of new observations to predict on.
#' @param times Numeric vector of time points at which to compute survival or risk probabilities.
-#' Required if `type = "survival"` or `type = "risk"`.
+#' Required if `type = "survival"` or `type = "risk"` for Cox/AFT models.
+#' For CoxTime, `times = NULL` is allowed when `type="survival"` and defaults to event times.
#' @param type Character string specifying the type of prediction to return:
#' \describe{
-#' \item{"lp"}{Linear predictor (log-risk score; higher implies worse prognosis).}
+#' \item{"lp"}{Linear predictor. For `"cox"`/`"cox_l2"` this is a log-risk score
+#' (higher implies worse prognosis, consistent with training sign convention). For `"aft"`,
+#' this is the predicted location parameter \eqn{\mu(x)} on the log-time scale. For `"coxtime"`,
+#' this is \eqn{g(t_0, x)} evaluated at a reference time \eqn{t_0} (the first event time).}
#' \item{"survival"}{Predicted survival probabilities at each value of `times`.}
-#' \item{"risk"}{Cumulative risk (1 - survival) at a single time point.}
+#' \item{"risk"}{Cumulative risk (1 - survival) at **a single** time point.}
#' }
#' @param ... Currently ignored (for future extensions).
#'
@@ -19,24 +23,6 @@
#' (if `type = "survival"`) with one row per observation and one column per `times`.
#'
#' @export
-#'
-#' @examples
-#' \donttest{
-#' library(survival)
-#' data(veteran, package = "survival")
-#'
-#' mod <- survdnn(
-#' Surv(time, status) ~ age + karno + celltype,
-#' data = veteran,
-#' loss = "cox",
-#' epochs = 50,
-#' verbose = FALSE
-#' )
-#'
-#' predict(mod, newdata = veteran, type = "lp")[1:5]
-#' predict(mod, newdata = veteran, type = "survival", times = c(30, 90, 180))[1:5, ]
-#' predict(mod, newdata = veteran, type = "risk", times = 180)[1:5]
-#' }
predict.survdnn <- function(
object,
newdata,
@@ -53,12 +39,7 @@ predict.survdnn <- function(
loss <- object$loss
model <- object$model
- device <- if (!is.null(object$device)) {
- object$device
- } else {
- torch::torch_device("cpu")
- }
-
+ device <- if (!is.null(object$device)) object$device else torch::torch_device("cpu")
model$to(device = device)
model$eval()
@@ -76,11 +57,20 @@ predict.survdnn <- function(
scale = object$x_scale
)
+ ## IMPORTANT: type='risk' is defined at a single time point
+ if (type == "risk" && !is.null(times) && length(times) != 1) {
+ stop("For type = 'risk', `times` must be a single numeric value.", call. = FALSE)
+ }
+
## ================================================================
## Cox / Cox L2
## ================================================================
if (loss %in% c("cox", "cox_l2")) {
+ if (type %in% c("survival", "risk") && is.null(times)) {
+ stop("`times` must be specified for type = 'survival' or 'risk'.", call. = FALSE)
+ }
+
x_tensor <- torch::torch_tensor(
x_scaled,
dtype = torch::torch_float(),
@@ -93,14 +83,7 @@ predict.survdnn <- function(
if (type == "lp") return(lp)
- if (is.null(times)) {
- stop("`times` must be specified for type = 'survival' or 'risk'.")
- }
-
- if (type == "risk" && length(times) != 1) {
- stop("For type = 'risk', `times` must be a single value.")
- }
-
+ ## baseline hazard via Breslow on training data
train_x <- stats::model.matrix(
stats::delete.response(tt),
object$data
@@ -131,16 +114,12 @@ predict.survdnn <- function(
)
bh <- survival::basehaz(
- survival::coxph(Surv(time, status) ~ lp, data = train_df),
+ survival::coxph(survival::Surv(time, status) ~ lp, data = train_df),
centered = FALSE
)
- H0 <- stats::approx(
- bh$time,
- bh$hazard,
- xout = sort(times),
- rule = 2
- )$y
+ times_sorted <- sort(as.numeric(times))
+ H0 <- stats::approx(bh$time, bh$hazard, xout = times_sorted, rule = 2)$y
surv_mat <- outer(
lp,
@@ -148,19 +127,21 @@ predict.survdnn <- function(
function(lp_i, h0_j) exp(-h0_j * exp(lp_i))
)
- if (type == "risk") {
- return(1 - surv_mat[, length(times)])
- }
+ if (type == "risk") return(1 - surv_mat[, 1])
- colnames(surv_mat) <- paste0("t=", sort(times))
+ colnames(surv_mat) <- paste0("t=", times_sorted)
return(as.data.frame(surv_mat))
}
## ================================================================
- ## AFT
+ ## AFT (log-normal AFT with learned global sigma + training centering)
## ================================================================
if (loss == "aft") {
+ if (type %in% c("survival", "risk") && is.null(times)) {
+ stop("`times` must be specified for type = 'survival' or 'risk'.", call. = FALSE)
+ }
+
x_tensor <- torch::torch_tensor(
x_scaled,
dtype = torch::torch_float(),
@@ -168,29 +149,34 @@ predict.survdnn <- function(
)
torch::with_no_grad({
- pred <- as.numeric(model(x_tensor)[, 1])
+ mu_raw <- as.numeric(model(x_tensor)[, 1])
})
- if (type == "lp") return(pred)
+ loc <- if (!is.null(object$aft_loc) && is.finite(object$aft_loc)) object$aft_loc else 0
+ mu <- mu_raw + loc
+ if (type == "lp") return(mu)
- if (is.null(times)) {
- y_train <- model.response(model.frame(object$formula, object$data))
- times <- sort(unique(y_train[, "time"]))
- }
+ ## sigma: must be finite and > 0; otherwise fall back to 1
+ ls <- object$aft_log_sigma
+ sigma <- if (!is.null(ls) && is.finite(ls)) exp(ls) else 1
+ if (!is.finite(sigma) || sigma <= 0) sigma <- 1
- logt <- log(sort(times))
+ times_sorted <- sort(as.numeric(times))
+ times_sorted[times_sorted <= 0] <- .Machine$double.eps
+ logt <- log(times_sorted)
surv_mat <- outer(
- pred,
+ mu,
logt,
- function(fx, lt) 1 - pnorm(lt - fx)
+ function(mu_i, lt) 1 - stats::pnorm((lt - mu_i) / sigma)
)
- if (type == "risk") {
- return(1 - surv_mat[, length(times)])
- }
+ surv_mat[surv_mat < 0] <- 0
+ surv_mat[surv_mat > 1] <- 1
- colnames(surv_mat) <- paste0("t=", sort(times))
+ if (type == "risk") return(1 - surv_mat[, 1])
+
+ colnames(surv_mat) <- paste0("t=", times_sorted)
return(as.data.frame(surv_mat))
}
@@ -200,11 +186,27 @@ predict.survdnn <- function(
if (loss == "coxtime") {
y_train <- model.response(model.frame(object$formula, object$data))
-
time_train <- y_train[, "time"]
status_train <- y_train[, "status"]
event_times <- sort(unique(time_train[status_train == 1]))
+ if (length(event_times) == 0) {
+ stop("CoxTime prediction requires at least one event in training data.", call. = FALSE)
+ }
+
+ ## --- time scaling used in training (fallback-safe) ---
+ t_center <- if (!is.null(object$coxtime_time_center) && is.finite(object$coxtime_time_center)) {
+ as.numeric(object$coxtime_time_center)
+ } else 0
+
+ t_scale <- if (!is.null(object$coxtime_time_scale) && is.finite(object$coxtime_time_scale) &&
+ as.numeric(object$coxtime_time_scale) > 0) {
+ as.numeric(object$coxtime_time_scale)
+ } else 1
+
+ scale_t <- function(t) (as.numeric(t) - t_center) / t_scale
+
+ ## training covariates (scaled) for baseline computation
train_x <- stats::model.matrix(
stats::delete.response(tt),
object$data
@@ -216,117 +218,95 @@ predict.survdnn <- function(
scale = object$x_scale
)
- if (length(event_times) == 0) {
- stop(
- "CoxTime prediction requires at least one event in training data.",
- call. = FALSE
- )
- }
-
- ## type = "lp"
+ ## type = "lp": define lp at a reference time (first event time)
if (type == "lp") {
-
t0 <- event_times[1]
- x_temp <- cbind(t0, x_scaled)
-
- x_tensor <- torch::torch_tensor(
- x_temp,
- dtype = torch::torch_float(),
- device = device
- )
-
+ t0s <- scale_t(t0)
+ x_temp <- cbind(t0s, x_scaled)
+ x_tensor <- torch::torch_tensor(x_temp, dtype = torch::torch_float(), device = device)
torch::with_no_grad({
lp <- as.numeric(model(x_tensor)[, 1])
})
-
return(lp)
}
- if (is.null(times)) times <- event_times
- times <- sort(unique(times))
-
- ## g(T_i, x_new)
- g_new_mat <- matrix(
- NA_real_,
- nrow = nrow(x_scaled),
- ncol = length(event_times)
- )
+ ## For CoxTime: allow times=NULL for survival -> default event_times (RAW)
+ if (type == "survival" && is.null(times)) {
+ times_sorted <- event_times
+ } else {
+ if (type %in% c("survival", "risk") && is.null(times)) {
+ stop("`times` must be specified for type = 'survival' or 'risk'.", call. = FALSE)
+ }
+ times_sorted <- sort(unique(as.numeric(times)))
+ }
+ ## ------------------------------------------------------------
+ ## Compute g(t_k, x_new) on event-time grid
+ ## NOTE: net expects SCALED time input
+ ## ------------------------------------------------------------
+ g_new_mat <- matrix(NA_real_, nrow = nrow(x_scaled), ncol = length(event_times))
for (j in seq_along(event_times)) {
-
- x_temp <- cbind(event_times[j], x_scaled)
-
- x_tensor <- torch::torch_tensor(
- x_temp,
- dtype = torch::torch_float(),
- device = device
- )
-
+ tj <- event_times[j]
+ tjs <- scale_t(tj)
+ x_temp <- cbind(tjs, x_scaled)
+ x_tensor <- torch::torch_tensor(x_temp, dtype = torch::torch_float(), device = device)
torch::with_no_grad({
g_new_mat[, j] <- as.numeric(model(x_tensor)[, 1])
})
}
- ## g(T_i, x_train)
- g_train_mat <- matrix(
- NA_real_,
- nrow = nrow(train_x_scaled),
- ncol = length(event_times)
- )
-
+ ## ------------------------------------------------------------
+ ## Compute g(t_k, x_train) on event-time grid (scaled time input)
+ ## ------------------------------------------------------------
+ g_train_mat <- matrix(NA_real_, nrow = nrow(train_x_scaled), ncol = length(event_times))
for (j in seq_along(event_times)) {
-
- x_temp <- cbind(event_times[j], train_x_scaled)
-
- x_tensor <- torch::torch_tensor(
- x_temp,
- dtype = torch::torch_float(),
- device = device
- )
-
+ tj <- event_times[j]
+ tjs <- scale_t(tj)
+ x_temp <- cbind(tjs, train_x_scaled)
+ x_tensor <- torch::torch_tensor(x_temp, dtype = torch::torch_float(), device = device)
torch::with_no_grad({
g_train_mat[, j] <- as.numeric(model(x_tensor)[, 1])
})
}
- dN <- as.numeric(table(factor(time_train[status_train == 1], levels = event_times)))
- denom <- colSums(exp(g_train_mat), na.rm = TRUE)
- dH0 <- dN / denom
-
- H_pred <- matrix(
- NA_real_,
- nrow = nrow(g_new_mat),
- ncol = length(times)
- )
-
- for (i in seq_along(times)) {
-
- relevant <- which(event_times <= times[i])
-
+ ## ------------------------------------------------------------
+ ## Baseline increments: dH0(t_k) = dN(t_k) / sum_{j in R(t_k)} exp(g(t_k, x_j))
+ ## risk sets are defined on RAW time (correct)
+ ## ------------------------------------------------------------
+ dN <- as.numeric(table(factor(time_train[status_train == 1], levels = event_times)))
+
+ risk_mat <- outer(time_train, event_times, `>=`)
+ denom <- colSums(exp(g_train_mat) * risk_mat, na.rm = TRUE)
+
+ denom[denom <= 0] <- NA_real_
+ dH0 <- dN / denom
+ dH0[is.na(dH0)] <- 0
+
+ ## ------------------------------------------------------------
+ ## Cumulative hazard at requested times (RAW time grid)
+ ## ------------------------------------------------------------
+ H_pred <- matrix(0, nrow = nrow(g_new_mat), ncol = length(times_sorted))
+ for (i in seq_along(times_sorted)) {
+ relevant <- which(event_times <= times_sorted[i])
if (length(relevant) > 0) {
-
H_pred[, i] <- rowSums(
exp(g_new_mat[, relevant, drop = FALSE]) *
- matrix(
- rep(dH0[relevant], each = nrow(g_new_mat)),
- nrow = nrow(g_new_mat)
- )
+ matrix(rep(dH0[relevant], each = nrow(g_new_mat)), nrow = nrow(g_new_mat))
)
-
} else {
H_pred[, i] <- 0
}
}
S_pred <- exp(-H_pred)
+ S_pred[S_pred < 0] <- 0
+ S_pred[S_pred > 1] <- 1
- if (type == "risk") {
- return(1 - S_pred[, length(times)])
- }
+ if (type == "risk") return(1 - S_pred[, 1])
- colnames(S_pred) <- paste0("t=", times)
+ colnames(S_pred) <- paste0("t=", times_sorted)
return(as.data.frame(S_pred))
}
- stop("Unsupported loss type for prediction: ", loss)
+ stop("Unsupported loss type for prediction: ", loss, call. = FALSE)
}
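Reviewer note: the baseline-hazard increments used above, with risk sets taken on raw time (matching the CoxTime denominator fix), follow the Breslow estimator. A framework-free sketch, assuming unit predictions per subject; names here are illustrative, not package API:

```python
import math

def breslow_increments(times, status, lp):
    # dH0(t_k) = dN(t_k) / sum_{j in R(t_k)} exp(lp_j),
    # with the risk set R(t_k) = {j : t_j >= t_k} defined on raw time.
    event_times = sorted({t for t, d in zip(times, status) if d == 1})
    dH0 = {}
    for tk in event_times:
        dN = sum(1 for t, d in zip(times, status) if d == 1 and t == tk)
        denom = sum(math.exp(l) for t, l in zip(times, lp) if t >= tk)
        dH0[tk] = dN / denom
    return dH0

def cum_hazard(dH0, t):
    # H0(t): sum of increments at event times up to t; S0(t) = exp(-H0(t))
    return sum(v for tk, v in dH0.items() if tk <= t)
```

With all linear predictors equal to zero, each denominator is just the risk-set size, which makes the estimator easy to verify by hand.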
diff --git a/R/survdnn.R b/R/survdnn.R
index 44c4aff..af66dd2 100644
--- a/R/survdnn.R
+++ b/R/survdnn.R
@@ -17,41 +17,41 @@
#' @keywords internal
#' @export
build_dnn <- function(input_dim,
- hidden,
- activation = "relu",
- output_dim = 1L,
- dropout = 0.3,
- batch_norm = TRUE) {
-
- layers <- list()
- in_features <- input_dim
-
- act_fn <- switch(
- activation,
- relu = torch::nn_relu,
- leaky_relu = torch::nn_leaky_relu,
- tanh = torch::nn_tanh,
- sigmoid = torch::nn_sigmoid,
- gelu = torch::nn_gelu,
- elu = torch::nn_elu,
- softplus = torch::nn_softplus,
- stop("Unsupported activation function: ", activation)
- )
-
- for (h in hidden) {
- layers <- append(layers, list(torch::nn_linear(in_features, h)))
- if (isTRUE(batch_norm)) {
- layers <- append(layers, list(torch::nn_batch_norm1d(h)))
- }
- layers <- append(layers, list(act_fn()))
- if (!is.null(dropout) && dropout > 0) {
- layers <- append(layers, list(torch::nn_dropout(p = dropout)))
- }
- in_features <- h
- }
-
- layers <- append(layers, list(torch::nn_linear(in_features, output_dim)))
- torch::nn_sequential(!!!layers)
+ hidden,
+ activation = "relu",
+ output_dim = 1L,
+ dropout = 0.3,
+ batch_norm = TRUE) {
+
+  layers <- list()
+  in_features <- input_dim
+
+  act_fn <- switch(
+    activation,
+    relu = torch::nn_relu,
+    leaky_relu = torch::nn_leaky_relu,
+    tanh = torch::nn_tanh,
+    sigmoid = torch::nn_sigmoid,
+    gelu = torch::nn_gelu,
+    elu = torch::nn_elu,
+    softplus = torch::nn_softplus,
+    stop("Unsupported activation function: ", activation)
+  )
+
+  for (h in hidden) {
+    layers <- append(layers, list(torch::nn_linear(in_features, h)))
+    if (isTRUE(batch_norm)) {
+      layers <- append(layers, list(torch::nn_batch_norm1d(h)))
+    }
+    layers <- append(layers, list(act_fn()))
+    if (!is.null(dropout) && dropout > 0) {
+      layers <- append(layers, list(torch::nn_dropout(p = dropout)))
+    }
+    in_features <- h
+  }
+
+  layers <- append(layers, list(torch::nn_linear(in_features, output_dim)))
+  torch::nn_sequential(!!!layers)
}
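Reviewer note: `build_dnn()` stacks, per hidden width, Linear -> optional BatchNorm -> activation -> optional Dropout, followed by a single Linear head. A hedged, framework-free sketch of that layer plan (names are illustrative, not package API):

```python
def mlp_layer_plan(input_dim, hidden, dropout=0.3, batch_norm=True, output_dim=1):
    # Per hidden width h: Linear -> [BatchNorm] -> activation -> [Dropout],
    # then one Linear output head, mirroring build_dnn() above.
    plan, in_features = [], input_dim
    for h in hidden:
        plan.append(("linear", in_features, h))
        if batch_norm:
            plan.append(("batch_norm", h))
        plan.append(("activation",))
        if dropout and dropout > 0:
            plan.append(("dropout", dropout))
        in_features = h
    plan.append(("linear", in_features, output_dim))
    return plan
```

For the default `hidden = c(32L, 16L)` this yields four entries per hidden layer plus the output head, which is the exact layer count the torch `nn_sequential` receives.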
@@ -107,197 +107,248 @@ build_dnn <- function(input_dim,
#' \item{optimizer}{Optimizer name used.}
#' \item{optim_args}{List of optimizer arguments used.}
#' \item{device}{Torch device used for training (`torch_device`).}
+#' \item{aft_log_sigma}{Learned global log(sigma) for `loss="aft"`; `NA_real_` otherwise.}
+#' \item{aft_loc}{AFT log-time location offset used for centering when `loss="aft"`; `NA_real_` otherwise.}
+#' \item{coxtime_time_center}{Mean used to scale time for CoxTime; `NA_real_` otherwise.}
+#' \item{coxtime_time_scale}{SD used to scale time for CoxTime; `NA_real_` otherwise.}
#' }
#' @export
survdnn <- function(formula, data,
-                    hidden = c(32L, 16L),
-                    activation = "relu",
-                    lr = 1e-4,
-                    epochs = 300L,
-                    loss = c("cox", "cox_l2", "aft", "coxtime"),
-                    optimizer = c("adam", "adamw", "sgd", "rmsprop", "adagrad"),
-                    optim_args = list(),
-                    verbose = TRUE,
-                    dropout = 0.3,
-                    batch_norm = TRUE,
-                    callbacks = NULL,
-                    .seed = NULL,
-                    .device = c("auto", "cpu", "cuda"),
-                    na_action = c("omit", "fail")) {
-
-  survdnn_set_seed(.seed)
-
-  device <- survdnn_get_device(.device)
-
-  loss <- match.arg(loss)
-  optimizer <- match.arg(optimizer)
-  na_action <- match.arg(na_action)
-
-  if (!is.list(optim_args)) {
-    stop("`optim_args` must be a list (possibly empty).", call. = FALSE)
-  }
-
-  if (!is.null(callbacks)) {
-    if (is.function(callbacks)) {
-      callbacks <- list(callbacks)
-    } else if (!is.list(callbacks) || !all(vapply(callbacks, is.function, logical(1)))) {
-      stop("`callbacks` must be NULL, a function, or a list of functions.", call. = FALSE)
-    }
-  }
-
-  stopifnot(inherits(formula, "formula"))
-  stopifnot(is.data.frame(data))
-
-  loss_fn <- switch(
-    loss,
-    cox = cox_loss,
-    cox_l2 = function(pred, true) cox_l2_loss(pred, true, lambda = 1e-3),
-    aft = aft_loss,
-    coxtime = coxtime_loss
-  )
-
-  environment(formula) <- list2env(
-    list(Surv = survival::Surv),
-    parent = environment(formula)
-  )
-
-  ## explicit missing-data handling
-  n_before <- nrow(data)
-
-  mf <- model.frame(
-    formula,
-    data = data,
-    na.action = if (na_action == "omit") stats::na.omit else stats::na.fail
-  )
-
-  n_after <- nrow(mf)
-  n_removed <- n_before - n_after
-
-  if (n_removed > 0 && isTRUE(verbose) && na_action == "omit") {
-    message(sprintf("Removed %d observations with missing values.", n_removed))
-  }
-
-  y <- model.response(mf)
-  x <- model.matrix(attr(mf, "terms"), data = mf)[, -1, drop = FALSE]
-  time <- y[, "time"]
-  status <- y[, "status"]
-  x_scaled <- scale(x)
-
-  x_tensor <- if (loss == "coxtime") {
-    torch::torch_tensor(
-      cbind(time, x_scaled),
-      dtype = torch::torch_float(),
-      device = device
-    )
-  } else {
-    torch::torch_tensor(
-      x_scaled,
-      dtype = torch::torch_float(),
-      device = device
-    )
-  }
-
-  y_tensor <- torch::torch_tensor(
-    cbind(time, status),
-    dtype = torch::torch_float(),
-    device = device
-  )
-
-  ## build network with dropout + batch_norm controls
-  net <- build_dnn(
-    input_dim = ncol(x_tensor),
-    hidden = hidden,
-    activation = activation,
-    output_dim = 1L,
-    dropout = dropout,
-    batch_norm = batch_norm
-  )
-  net$to(device = device)
-
-  ## build optimizer with dispatcher
-  opt_args <- c(
-    list(params = net$parameters, lr = lr),
-    optim_args
-  )
-
-  ## default weight decay for adam/adamw if not provided
-  if (is.null(optim_args$weight_decay) && optimizer %in% c("adam", "adamw")) {
-    opt_args$weight_decay <- 1e-4
-  }
-
-  optimizer_obj <- switch(
-    optimizer,
-    adam = do.call(torch::optim_adam, opt_args),
-    adamw = do.call(torch::optim_adamw, opt_args),
-    sgd = do.call(torch::optim_sgd, opt_args),
-    rmsprop = do.call(torch::optim_rmsprop, opt_args),
-    adagrad = do.call(torch::optim_adagrad, opt_args),
-    stop("Unsupported optimizer: ", optimizer)
-  )
-
-  loss_history <- numeric(epochs)
-  early_stopped <- FALSE
-  last_epoch_run <- epochs
-
-  for (epoch in 1:epochs) {
-    net$train()
-    optimizer_obj$zero_grad()
-
-    pred <- net(x_tensor)
-    loss_val <- loss_fn(pred, y_tensor)
-
-    loss_val$backward()
-    optimizer_obj$step()
-
-    current_loss <- loss_val$item()
-    loss_history[epoch] <- current_loss
-    last_epoch_run <- epoch
-
-    if (verbose && epoch %% 50 == 0) {
-      cat(sprintf("Epoch %d - Loss: %.6f\n", epoch, current_loss))
-      cat("\n")
-    }
-
-    ## callbacks
-    if (!is.null(callbacks)) {
-      for (cb in callbacks) {
-        stop_now <- isTRUE(cb(epoch, current_loss))
-        if (stop_now) {
-          early_stopped <- TRUE
-          break
-        }
-      }
-      if (early_stopped) break
-    }
-  }
-
-  ## truncate loss history if early stopping
-  if (early_stopped && last_epoch_run < epochs) {
-    loss_history <- loss_history[seq_len(last_epoch_run)]
-  }
-
-  structure(
-    list(
-      model = net,
-      formula = formula,
-      data = data,
-      xnames = colnames(x),
-      x_center = attr(x_scaled, "scaled:center"),
-      x_scale = attr(x_scaled, "scaled:scale"),
-      loss_history = loss_history,
-      final_loss = tail(loss_history, 1),
-      loss = loss,
-      activation = activation,
-      hidden = hidden,
-      lr = lr,
-      epochs = epochs,
-      optimizer = optimizer,
-      optim_args = optim_args,
-      device = device,
-      dropout = dropout,
-      batch_norm = batch_norm,
-      na_action = na_action
-    ),
-    class = "survdnn"
-  )
+                    hidden = c(32L, 16L),
+                    activation = "relu",
+                    lr = 1e-4,
+                    epochs = 300L,
+                    loss = c("cox", "cox_l2", "aft", "coxtime"),
+                    optimizer = c("adam", "adamw", "sgd", "rmsprop", "adagrad"),
+                    optim_args = list(),
+                    verbose = TRUE,
+                    dropout = 0.3,
+                    batch_norm = TRUE,
+                    callbacks = NULL,
+                    .seed = NULL,
+                    .device = c("auto", "cpu", "cuda"),
+                    na_action = c("omit", "fail")) {
+
+  survdnn_set_seed(.seed)
+  device <- survdnn_get_device(.device)
+
+  loss <- match.arg(loss)
+  optimizer <- match.arg(optimizer)
+  na_action <- match.arg(na_action)
+
+  if (!is.list(optim_args)) {
+    stop("`optim_args` must be a list (possibly empty).", call. = FALSE)
+  }
+
+  if (!is.null(callbacks)) {
+    if (is.function(callbacks)) {
+      callbacks <- list(callbacks)
+    } else if (!is.list(callbacks) || !all(vapply(callbacks, is.function, logical(1)))) {
+      stop("`callbacks` must be NULL, a function, or a list of functions.", call. = FALSE)
+    }
+  }
+
+  stopifnot(inherits(formula, "formula"))
+  stopifnot(is.data.frame(data))
+
+  environment(formula) <- list2env(
+    list(Surv = survival::Surv),
+    parent = environment(formula)
+  )
+
+  # ---- missing data handling ----
+  n_before <- nrow(data)
+  mf <- model.frame(
+    formula,
+    data = data,
+    na.action = if (na_action == "omit") stats::na.omit else stats::na.fail
+  )
+  n_after <- nrow(mf)
+  n_removed <- n_before - n_after
+  if (n_removed > 0 && isTRUE(verbose) && na_action == "omit") {
+    message(sprintf("Removed %d observations with missing values.", n_removed))
+  }
+
+  y <- model.response(mf)
+  x <- model.matrix(attr(mf, "terms"), data = mf)[, -1, drop = FALSE]
+  time <- y[, "time"]
+  status <- y[, "status"]
+  x_scaled <- scale(x)
+
+  # ---- AFT location offset for stability ----
+  aft_loc <- NA_real_
+  if (loss == "aft") {
+    evt <- (status == 1)
+    if (any(evt)) {
+      aft_loc <- mean(log(pmax(time[evt], .Machine$double.eps)))
+    } else {
+      aft_loc <- mean(log(pmax(time, .Machine$double.eps)))
+    }
+    if (!is.finite(aft_loc)) aft_loc <- 0
+  }
+
+  # ---- CoxTime time scaling (CRITICAL for heterogeneity) ----
+  coxtime_time_center <- NA_real_
+  coxtime_time_scale <- NA_real_
+  time_scaled <- NULL
+
+  if (loss == "coxtime") {
+    ts <- scale(as.numeric(time))
+    coxtime_time_center <- as.numeric(attr(ts, "scaled:center"))
+    coxtime_time_scale <- as.numeric(attr(ts, "scaled:scale"))
+    if (!is.finite(coxtime_time_scale) || coxtime_time_scale <= 0) coxtime_time_scale <- 1
+    time_scaled <- as.numeric(ts)
+  }
+
+  # ---- tensors ----
+  # x_tensor:
+  #   - coxtime: [time_scaled, x_scaled] (time as fed to net)
+  #   - others : [x_scaled]
+  x_tensor <- if (loss == "coxtime") {
+    torch::torch_tensor(
+      cbind(time_scaled, x_scaled),
+      dtype = torch::torch_float(),
+      device = device
+    )
+  } else {
+    torch::torch_tensor(
+      x_scaled,
+      dtype = torch::torch_float(),
+      device = device
+    )
+  }
+
+  # y_tensor always uses RAW time for ordering/risk sets
+  y_tensor <- torch::torch_tensor(
+    cbind(time, status),
+    dtype = torch::torch_float(),
+    device = device
+  )
+
+  # ---- network ----
+  net <- build_dnn(
+    input_dim = ncol(x_tensor),
+    hidden = hidden,
+    activation = activation,
+    output_dim = 1L,
+    dropout = dropout,
+    batch_norm = batch_norm
+  )
+  net$to(device = device)
+
+  # ---- loss dispatcher + (optional) AFT extra params ----
+  extra_params <- NULL      # list for AFT, NULL otherwise
+  aft_log_sigma <- NA_real_ # ALWAYS numeric
+  loss_fn <- NULL
+
+  if (loss == "cox") {
+    loss_fn <- function(net, x, y) cox_loss(net(x), y)
+  } else if (loss == "cox_l2") {
+    loss_fn <- function(net, x, y) cox_l2_loss(net(x), y, lambda = 1e-3)
+  } else if (loss == "aft") {
+    loc0 <- if (is.finite(aft_loc)) aft_loc else 0
+    aft_bundle <- survdnn__aft_lognormal_nll_factory(device = device, aft_loc = loc0)
+    extra_params <- aft_bundle$extra_params
+    loss_fn <- function(net, x, y) aft_bundle$loss_fn(net, x, y)
+  } else if (loss == "coxtime") {
+    lf <- survdnn__coxtime_loss_factory(net)
+    loss_fn <- function(net, x, y) lf(x, y)
+  } else {
+    stop("Unsupported loss: ", loss, call. = FALSE)
+  }
+
+  # ---- optimizer params ----
+  params <- net$parameters
+  if (loss == "aft" && !is.null(extra_params) && !is.null(extra_params$log_sigma)) {
+    params <- c(params, list(extra_params$log_sigma))
+  }
+
+  opt_args <- c(list(params = params, lr = lr), optim_args)
+
+  if (is.null(optim_args$weight_decay) && optimizer %in% c("adam", "adamw")) {
+    opt_args$weight_decay <- 1e-4
+  }
+
+  optimizer_obj <- switch(
+    optimizer,
+    adam = do.call(torch::optim_adam, opt_args),
+    adamw = do.call(torch::optim_adamw, opt_args),
+    sgd = do.call(torch::optim_sgd, opt_args),
+    rmsprop = do.call(torch::optim_rmsprop, opt_args),
+    adagrad = do.call(torch::optim_adagrad, opt_args),
+    stop("Unsupported optimizer: ", optimizer)
+  )
+
+  # ---- training loop ----
+  loss_history <- numeric(epochs)
+  early_stopped <- FALSE
+  last_epoch_run <- epochs
+
+  for (epoch in seq_len(epochs)) {
+    net$train()
+    optimizer_obj$zero_grad()
+
+    loss_val <- loss_fn(net, x_tensor, y_tensor)
+    loss_val$backward()
+    optimizer_obj$step()
+
+    current_loss <- loss_val$item()
+    loss_history[epoch] <- current_loss
+    last_epoch_run <- epoch
+
+    if (verbose && epoch %% 50 == 0) {
+      cat(sprintf("Epoch %d - Loss: %.6f\n\n", epoch, current_loss))
+    }
+
+    if (!is.null(callbacks)) {
+      for (cb in callbacks) {
+        if (isTRUE(cb(epoch, current_loss))) {
+          early_stopped <- TRUE
+          break
+        }
+      }
+      if (early_stopped) break
+    }
+  }
+
+  if (early_stopped && last_epoch_run < epochs) {
+    loss_history <- loss_history[seq_len(last_epoch_run)]
+  }
+
+  # ---- store learned AFT log(sigma) robustly ----
+  if (loss == "aft" && !is.null(extra_params) && !is.null(extra_params$log_sigma)) {
+    aft_log_sigma <- as.numeric(extra_params$log_sigma$item())
+    if (!is.finite(aft_log_sigma)) aft_log_sigma <- NA_real_
+  } else {
+    aft_log_sigma <- NA_real_
+  }
+
+  structure(
+    list(
+      model = net,
+      formula = formula,
+      data = data,
+      xnames = colnames(x),
+      x_center = attr(x_scaled, "scaled:center"),
+      x_scale = attr(x_scaled, "scaled:scale"),
+      loss_history = loss_history,
+      final_loss = tail(loss_history, 1),
+      loss = loss,
+      activation = activation,
+      hidden = hidden,
+      lr = lr,
+      epochs = epochs,
+      optimizer = optimizer,
+      optim_args = optim_args,
+      device = device,
+      dropout = dropout,
+      batch_norm = batch_norm,
+      na_action = na_action,
+      aft_log_sigma = aft_log_sigma,
+      aft_loc = if (loss == "aft") aft_loc else NA_real_,
+      coxtime_time_center = if (loss == "coxtime") coxtime_time_center else NA_real_,
+      coxtime_time_scale = if (loss == "coxtime") coxtime_time_scale else NA_real_
+    ),
+    class = "survdnn"
+  )
}
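
The callback contract above (each element of `callbacks` is invoked as `cb(epoch, current_loss)`, and training stops once one returns `TRUE`) can be exercised with a small closure. A minimal patience-based sketch in base R, with `make_early_stopper` as an illustrative name rather than a package export:

```r
# Illustrative helper (not part of survdnn): returns a callback that stops
# training once the loss has failed to improve by at least `min_delta`
# for `patience` consecutive epochs.
make_early_stopper <- function(patience = 10, min_delta = 1e-4) {
  best <- Inf
  wait <- 0
  function(epoch, loss) {
    if (loss < best - min_delta) {
      best <<- loss
      wait <<- 0
    } else {
      wait <<- wait + 1
    }
    wait >= patience  # TRUE signals survdnn() to stop training
  }
}

# usage sketch: survdnn(..., callbacks = make_early_stopper(patience = 20))
```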
diff --git a/README.Rmd b/README.Rmd
index 493dc8e..14faaf1 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -4,31 +4,40 @@ output: github_document
# survdnn
-> Deep Neural Networks for Survival Analysis Using [torch](https://torch.mlverse.org/)
+> Deep Neural Networks for Survival Analysis using [torch](https://torch.mlverse.org/)
[](LICENSE)
[](https://github.com/ielbadisy/survdnn/actions/workflows/R-CMD-check.yaml)
----
-`survdnn` implements neural network-based models for right-censored survival analysis using the native `torch` backend in R. It supports multiple loss functions including Cox partial likelihood, L2-penalized Cox, Accelerated Failure Time (AFT) objectives, as well as time-dependent extension such as Cox-Time. The package provides a formula interface, supports model evaluation using time-dependent metrics (e.g., C-index, Brier score, IBS), cross-validation, and hyperparameter tuning.
+## About
+
+`survdnn` implements neural network-based models for right-censored survival analysis using the native `torch` backend in R. It supports multiple loss functions, including the Cox partial likelihood, L2-penalized Cox, and Accelerated Failure Time (AFT) objectives, as well as time-dependent extensions such as Cox-Time. The package provides a formula interface and supports model evaluation with time-dependent metrics (C-index, Brier score, IBS), cross-validation, and hyperparameter tuning.
+
+## Review status
+
+A methodological paper describing the design, implementation, and evaluation of `survdnn` is currently under review at _The R Journal_.
----
-## Features
+## Main features
- Formula interface for `Surv() ~ .` models
-- Modular neural architectures: configurable layers, activations, and losses
+
+- Modular neural architectures: configurable layers, activations, optimizers, and losses
+
- Built-in survival loss functions:
+
- `"cox"`: Cox partial likelihood
- `"cox_l2"`: penalized Cox
- `"aft"`: Accelerated Failure Time
- - `"coxtime"`: deep time-dependent Cox (like DeepSurv)
-- Evaluation: C-index, Brier score, Integrated Brier Score (IBS)
+ - `"coxtime"`: deep time-dependent Cox
+
+- Evaluation: C-index, Brier score, IBS
+
- Model selection with `cv_survdnn()` and `tune_survdnn()`
+
- Prediction of survival curves via `predict()` and `plot()`
----
## Installation
@@ -47,9 +56,8 @@ setwd("survdnn")
devtools::install()
```
----
-## Quick Example
+## Quick example
```{r, message = FALSE, warning = FALSE}
library(survdnn)
@@ -62,13 +70,11 @@ mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
hidden = c(32, 16),
- epochs = 100,
+ epochs = 300,
loss = "cox",
verbose = TRUE
)
-```
-```{r}
summary(mod)
```
@@ -76,62 +82,64 @@ summary(mod)
plot(mod, group_by = "celltype", times = 1:300)
```
----
## Loss Functions
+- Cox partial likelihood
+
```{r}
-# Cox partial likelihood
mod1 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "cox",
- epochs = 100
+ epochs = 300
)
```
+- Accelerated Failure Time
+
```{r}
-# Accelerated Failure Time
mod2 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "aft",
- epochs = 100
+ epochs = 300
)
```
+- Coxtime
+
```{r}
-# Deep time-dependent Cox (Coxtime)
mod3 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "coxtime",
- epochs = 100
+ epochs = 300
)
```
----
-## Cross-Validation
+## Cross-validation
```{r, eval = FALSE}
cv_results <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
- times = c(30, 90, 180),
+ times = c(600),
metrics = c("cindex", "ibs"),
folds = 3,
hidden = c(16, 8),
loss = "cox",
- epochs = 100
+ epochs = 300
)
+
print(cv_results)
```
----
-## Hyperparameter Tuning
+
+## Hyperparameter tuning
```{r, eval = FALSE}
grid <- list(
@@ -152,23 +160,56 @@ tune_res <- tune_survdnn(
refit = FALSE,
return = "summary"
)
+
print(tune_res)
```
----
-## Plot Survival Curves
+
+## Tuning and refitting the best model
+
+`tune_survdnn()` can also be used to automatically refit the best-performing model on the full dataset. This behavior is controlled by the `refit` and `return` arguments. For example:
+
+```{r, eval = FALSE}
+best_model <- tune_survdnn(
+ formula = Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ times = c(90, 300),
+ metrics = "cindex",
+ param_grid = grid,
+ folds = 3,
+ refit = TRUE,
+ return = "best_model"
+ )
+```
+
+
+In this mode, cross-validation is used to select the optimal hyperparameter configuration, after which the selected model is refitted on the full dataset. The function then returns a fitted object of class `"survdnn"`.
+
+The resulting model can be used directly for prediction, visualization, and evaluation:
+
+```{r, eval = FALSE}
+summary(best_model)
+
+plot(best_model, times = 1:300)
+
+predict(best_model, veteran, type = "risk", times = 180)
+```
+
+This makes `tune_survdnn()` suitable for end-to-end workflows, combining model selection and final model fitting.
+
+
+
+## Plot survival curves
```{r}
plot(mod1, group_by = "celltype", times = 1:300)
```
-
```{r}
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
```
----
## Documentation
@@ -180,43 +221,81 @@ help(package = "survdnn")
?plot.survdnn
```
----
-
## Testing
```{r, eval = FALSE}
-# Run all tests
+# run all tests
devtools::test()
```
+## Note on reproducibility
----
+By default, `{torch}` initializes model weights and shuffles minibatches using random draws, so results may differ across runs. Unlike `set.seed()`, which only controls R's random number generator, `{torch}` relies on its own RNG implemented in C++ (and CUDA when using GPUs).
+
+To ensure reproducibility, random seeds must therefore be set at the Torch level as well.
+
+`survdnn` provides built-in control of randomness to guarantee reproducible results across runs. The main fitting function, `survdnn()`, exposes a dedicated `.seed` argument:
+
+```{r, eval = FALSE}
+mod <- survdnn(
+ Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ epochs = 300,
+ .seed = 123
+)
+```
+
+When `.seed` is provided, `survdnn()` internally synchronizes both R and Torch random number generators via `survdnn_set_seed()`, ensuring reproducibility of:
+
+* weight initialization
+
+* dropout behavior
+
+* minibatch ordering
-## Reproducibility
+* loss trajectories
-By default, Torch initializes model weights and shuffles minibatches with random draws, so results may differ at each run.
-Unlike `set.seed()`, which only controls R’s RNG, `{torch}` uses its own RNG implemented in C++/CUDA. To ensure reproducibility, set the Torch seed before training:
+If `.seed = NULL` (the default), randomness is left uncontrolled and results may vary between runs.
+
+For full reproducibility in cross-validation or hyperparameter tuning, the same `.seed` mechanism is propagated internally by `cv_survdnn()` and `tune_survdnn()`, ensuring consistent data splits, model initialization, and optimization paths across repetitions.
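+
+For instance, a sketch (assuming `.seed` is forwarded by `cv_survdnn()` as described above), which should yield identical folds and fits across runs:
+
+```{r, eval = FALSE}
+cv_results <- cv_survdnn(
+  Surv(time, status) ~ age + karno + celltype,
+  data = veteran,
+  times = 600,
+  metrics = "cindex",
+  folds = 3,
+  loss = "cox",
+  epochs = 300,
+  .seed = 123
+)
+```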
+
+## CPU and core usage
+
+`survdnn` relies on the `{torch}` backend for numerical computation. The number of CPU cores (threads) used during training, prediction, and evaluation is controlled globally by Torch.
+
+By default, Torch automatically configures its CPU thread pools based on the available system resources, unless explicitly overridden by the user:
```{r}
-torch::torch_manual_seed(123)
+torch::torch_set_num_threads(4)
```
----
+This setting affects:
+
+* model training
+
+* prediction
+
+* evaluation metrics
+
+* cross-validation and hyperparameter tuning
+
+GPU acceleration can be enabled by setting `.device = "cuda"` when calling `survdnn()` (and likewise `cv_survdnn()` and `tune_survdnn()`).
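+
+A minimal sketch (assuming a CUDA-enabled installation of `{torch}`):
+
+```{r, eval = FALSE}
+mod_gpu <- survdnn(
+  Surv(time, status) ~ age + karno + celltype,
+  data = veteran,
+  epochs = 300,
+  .device = "cuda"
+)
+```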
+
## Availability
-The `survdnn` R package is available on CRAN or at: https://github.com/ielbadisy/survdnn
+The `survdnn` R package is available on [CRAN](https://CRAN.R-project.org/package=survdnn) and on [GitHub](https://github.com/ielbadisy/survdnn).
----
## Contributions
-Contributions, issues, and feature requests are welcome.
-Open an [issue](https://github.com/ielbadisy/survdnn/issues) or submit a pull request!
+Contributions, issues, and feature requests are welcome!
+
+Open an [issue](https://github.com/ielbadisy/survdnn/issues) or submit a pull request.
----
## License
-MIT © [Imad El Badisy](mailto:elbadisyimad@gmail.com)
+MIT License © 2025 Imad EL BADISY
+
diff --git a/README.html b/README.html
index b4c5347..346a150 100644
--- a/README.html
+++ b/README.html
@@ -604,39 +604,41 @@
survdnn

-Deep Neural Networks for Survival Analysis Using torch
+Deep Neural Networks for Survival Analysis using torch


-
+About
survdnn implements neural network-based models for
right-censored survival analysis using the native torch
backend in R. It supports multiple loss functions including Cox partial
likelihood, L2-penalized Cox, Accelerated Failure Time (AFT) objectives,
as well as time-dependent extensions such as Cox-Time. The package
provides a formula interface, supports model evaluation using
-time-dependent metrics (e.g., C-index, Brier score, IBS),
-cross-validation, and hyperparameter tuning.
-
-Features
+time-dependent metrics (C-index, Brier score, IBS), cross-validation,
+and hyperparameter tuning.
+Review status
+A methodological paper describing the design, implementation, and
+evaluation of survdnn is currently under review at The
+R Journal.
+Main features
-- Formula interface for
Surv() ~ . models
-- Modular neural architectures: configurable layers, activations, and
-losses
-- Built-in survival loss functions:
+
Formula interface for Surv() ~ . models
+Modular neural architectures: configurable layers, activations,
+optimizers, and losses
+Built-in survival loss functions:
"cox": Cox partial likelihood
"cox_l2": penalized Cox
"aft": Accelerated Failure Time
-"coxtime": deep time-dependent Cox (like DeepSurv)
+"coxtime": deep time-dependent Cox
-- Evaluation: C-index, Brier score, Integrated Brier Score (IBS)
-- Model selection with
cv_survdnn() and
-tune_survdnn()
-- Prediction of survival curves via
predict() and
-plot()
+Evaluation: C-index, Brier score, IBS
+Model selection with cv_survdnn() and
+tune_survdnn()
+Prediction of survival curves via predict() and
+plot()
-
Installation
# Install from CRAN
install.packages("survdnn")
@@ -650,8 +652,7 @@ Installation
git clone https://github.com/ielbadisy/survdnn.git
setwd("survdnn")
devtools::install()
-
-Quick Example
+Quick example
library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)
@@ -662,32 +663,38 @@ Quick Example
Surv(time, status) ~ age + karno + celltype,
data = veteran,
hidden = c(32, 16),
- epochs = 100,
+ epochs = 300,
loss = "cox",
verbose = TRUE
)
-## Epoch 50 - Loss: 3.898330
-## Epoch 100 - Loss: 3.834461
+## Epoch 50 - Loss: 3.983201
+##
+## Epoch 100 - Loss: 3.947356
+##
+## Epoch 150 - Loss: 3.934828
+##
+## Epoch 200 - Loss: 3.876191
+##
+## Epoch 250 - Loss: 3.813223
+##
+## Epoch 300 - Loss: 3.868888
##
-
-## ── Summary of survdnn model ─────────────────────────────────────────────────────────────────────
-
-##
## Formula:
## Surv(time, status) ~ age + karno + celltype
-## <environment: 0x57f5687daa00>
+## <environment: 0x611459d0ec80>
##
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
-## Final loss: 3.834461
+## Final loss: 3.868888
##
## Training summary:
-## Epochs: 100
+## Epochs: 300
## Learning rate: 1e-04
## Loss function: cox
+## Optimizer: adam
##
## Data summary:
## Observations: 137
@@ -695,50 +702,81 @@ Quick Example
## Time range: [ 1, 999 ]
## Event rate: 93.4%
plot(mod, group_by = "celltype", times = 1:300)
-
Loss Functions
-# Cox partial likelihood
-mod1 <- survdnn(
- Surv(time, status) ~ age + karno,
- data = veteran,
- loss = "cox",
- epochs = 100
- )
-## Epoch 50 - Loss: 3.991873
-## Epoch 100 - Loss: 3.937163
-# Accelerated Failure Time
-mod2 <- survdnn(
- Surv(time, status) ~ age + karno,
- data = veteran,
- loss = "aft",
- epochs = 100
- )
-## Epoch 50 - Loss: 18.660992
-## Epoch 100 - Loss: 18.260056
-# Deep time-dependent Cox (Coxtime)
-mod3 <- survdnn(
- Surv(time, status) ~ age + karno,
- data = veteran,
- loss = "coxtime",
- epochs = 100
- )
-## Epoch 50 - Loss: 4.899240
-## Epoch 100 - Loss: 4.835490
-
-Cross-Validation
+
+- Cox partial likelihood
+
+mod1 <- survdnn(
+ Surv(time, status) ~ age + karno,
+ data = veteran,
+ loss = "cox",
+ epochs = 300
+ )
+## Epoch 50 - Loss: 3.986035
+##
+## Epoch 100 - Loss: 3.973183
+##
+## Epoch 150 - Loss: 3.944867
+##
+## Epoch 200 - Loss: 3.901533
+##
+## Epoch 250 - Loss: 3.849433
+##
+## Epoch 300 - Loss: 3.899746
+
+- Accelerated Failure Time
+
+mod2 <- survdnn(
+ Surv(time, status) ~ age + karno,
+ data = veteran,
+ loss = "aft",
+ epochs = 300
+ )
+## Epoch 50 - Loss: 18.154217
+##
+## Epoch 100 - Loss: 17.844833
+##
+## Epoch 150 - Loss: 17.560537
+##
+## Epoch 200 - Loss: 17.134348
+##
+## Epoch 250 - Loss: 16.840366
+##
+## Epoch 300 - Loss: 16.344124
+
+mod3 <- survdnn(
+ Surv(time, status) ~ age + karno,
+ data = veteran,
+ loss = "coxtime",
+ epochs = 300
+ )
+## Epoch 50 - Loss: 4.932558
+##
+## Epoch 100 - Loss: 4.864682
+##
+## Epoch 150 - Loss: 4.830169
+##
+## Epoch 200 - Loss: 4.784954
+##
+## Epoch 250 - Loss: 4.764827
+##
+## Epoch 300 - Loss: 4.731824
+Cross-validation
cv_results <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
- times = c(30, 90, 180),
+ times = c(600),
metrics = c("cindex", "ibs"),
folds = 3,
hidden = c(16, 8),
loss = "cox",
- epochs = 100
+ epochs = 300
)
-print(cv_results)
-
-Hyperparameter Tuning
+
+print(cv_results)
+Hyperparameter tuning
grid <- list(
hidden = list(c(16), c(32, 16)),
lr = c(1e-3),
@@ -757,42 +795,112 @@ Hyperparameter Tuning
refit = FALSE,
return = "summary"
)
-print(tune_res)
-
-Plot Survival Curves
-plot(mod1, group_by = "celltype", times = 1:300)
-
-plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
-
-
+
+print(tune_res)
+Tuning and refitting the
+best model
+tune_survdnn() can also be used to automatically refit
+the best-performing model on the full dataset. This behavior is
+controlled by the refit and return arguments.
+For example:
+best_model <- tune_survdnn(
+ formula = Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ times = c(90, 300),
+ metrics = "cindex",
+ param_grid = grid,
+ folds = 3,
+ refit = TRUE,
+ return = "best_model"
+ )
+In this mode, cross-validation is used to select the optimal
+hyperparameter configuration, after which the selected model is refitted
+on the full dataset. The function then returns a fitted object of class
+"survdnn".
+The resulting model can be used directly for prediction,
+visualization, and evaluation:
+summary(best_model)
+
+plot(best_model, times = 1:300)
+
+predict(best_model, veteran, type = "risk", times = 180)
+This makes tune_survdnn() suitable for end-to-end
+workflows, combining model selection and final model fitting.
+Plot survival curves
+plot(mod1, group_by = "celltype", times = 1:300)
+
+plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
+
Documentation
-help(package = "survdnn")
-?survdnn
-?tune_survdnn
-?cv_survdnn
-?plot.survdnn
-
+help(package = "survdnn")
+?survdnn
+?tune_survdnn
+?cv_survdnn
+?plot.survdnn
Testing
-# Run all tests
-devtools::test()
-
-Reproducibility
-By default, Torch initializes model weights and shuffles minibatches
-with random draws, so results may differ at each run.
-Unlike set.seed(), which only controls R’s RNG,
-{torch} uses its own RNG implemented in C++/CUDA. To ensure
-reproducibility, set the Torch seed before training:
-torch::torch_manual_seed(123)
-
+# run all tests
+devtools::test()
+Note on reproducibility
+By default, {torch} initializes model weights and
+shuffles minibatches using random draws, so results may differ across
+runs. Unlike set.seed(), which only controls R’s random
+number generator, {torch} relies on its own RNG implemented
+in C++ (and CUDA when using GPUs).
+To ensure reproducibility, random seeds must therefore be set at the
+Torch level as well.
+survdnn provides built-in control of randomness to
+guarantee reproducible results across runs. The main fitting function,
+survdnn(), exposes a dedicated .seed
+argument:
+mod <- survdnn(
+ Surv(time, status) ~ age + karno + celltype,
+ data = veteran,
+ epochs = 300,
+ .seed = 123
+)
+When .seed is provided, survdnn()
+internally synchronizes both R and Torch random number generators via
+survdnn_set_seed(), ensuring reproducibility of:
+
+weight initialization
+dropout behavior
+minibatch ordering
+loss trajectories
+
+If .seed = NULL (the default), randomness is left
+uncontrolled and results may vary between runs.
+For full reproducibility in cross-validation or hyperparameter
+tuning, the same .seed mechanism is propagated internally
+by cv_survdnn() and tune_survdnn(), ensuring
+consistent data splits, model initialization, and optimization paths
+across repetitions.
+CPU and core usage
+survdnn relies on the {torch} backend for
+numerical computation. The number of CPU cores (threads) used during
+training, prediction, and evaluation is controlled globally by
+Torch.
+By default, Torch automatically configures its CPU thread pools based
+on the available system resources, unless explicitly overridden by the
+user:
+torch::torch_set_num_threads(4)
+This setting affects:
+
+GPU acceleration can be enabled by setting
+.device = "cuda" when calling survdnn()
+(and likewise cv_survdnn() and tune_survdnn()).
Availability
-The survdnn R package is available on CRAN or at: https://github.com/ielbadisy/survdnn
-
+The survdnn R package is available on CRAN and on GitHub.
Contributions
-Contributions, issues, and feature requests are welcome. Open an issue or submit a
-pull request!
-
+Contributions, issues, and feature requests are welcome!
+Open an issue or submit a
+pull request.
License
-MIT © Imad El Badisy
+MIT License © 2025 Imad EL BADISY