diff --git a/R/colocboost_check_update_jk.R b/R/colocboost_check_update_jk.R index 4a224c8..bdc1f45 100644 --- a/R/colocboost_check_update_jk.R +++ b/R/colocboost_check_update_jk.R @@ -5,40 +5,22 @@ #' #' @return update_status and real_update_jk for each trait #' @noRd -colocboost_check_update_jk <- function(cb_model, cb_model_para, cb_data, - prioritize_jkstar = TRUE, - jk_equiv_corr = 0.8, ##### more than 2 traits - jk_equiv_loglik = 1, ## more than 2 traits - func_compare = "min_max", ##### more than 3 traits - coloc_thresh = 0.1) { +colocboost_check_update_jk <- function(cb_model, cb_model_para, cb_data) { + pos.update <- which(cb_model_para$update_y == 1) focal_outcome_idx <- cb_model_para$focal_outcome_idx if (is.null(focal_outcome_idx)) { - cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data, - prioritize_jkstar = prioritize_jkstar, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - func_compare = func_compare, - coloc_thresh = coloc_thresh - ) + cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data) } else { if (focal_outcome_idx %in% pos.update) { - cb_model_para <- boost_check_update_jk_focal(cb_model, cb_model_para, cb_data, - prioritize_jkstar = prioritize_jkstar, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - func_compare = func_compare, - coloc_thresh = coloc_thresh, + cb_model_para <- boost_check_update_jk_focal( + cb_model, + cb_model_para, + cb_data, focal_outcome_idx = focal_outcome_idx ) } else { - cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data, - prioritize_jkstar = prioritize_jkstar, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - func_compare = func_compare, - coloc_thresh = coloc_thresh - ) + cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data) } } return(cb_model_para) @@ -47,12 +29,8 @@ colocboost_check_update_jk <- function(cb_model, cb_model_para, cb_data, #' @importFrom stats median -boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data, - prioritize_jkstar = TRUE, - jk_equiv_corr = 0.8, ##### more than 2 traits - jk_equiv_loglik = 1, ## more than 2 traits - func_compare = "min_max", ##### more than 3 traits - coloc_thresh = 0.1) { +boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data) { + ############# Output ################# # we will obtain and update the cb_model_para$update_status and cb_model_para$real_update_jk ###################################### @@ -61,9 +39,11 @@ boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data, update_status <- rep(0, cb_model_para$L) update_jk <- rep(NA, cb_model_para$L + 1) real_update_jk <- rep(NA, cb_model_para$L) - if (is.null(cb_model_para$coloc_thresh)) { - cb_model_para$coloc_thresh <- (1 - coloc_thresh) * max(sapply(1:length(cb_model), function(i) max(cb_model[[i]]$change_loglike))) - } + # - initial parameter + prioritize_jkstar <- cb_model_para$prioritize_jkstar + jk_equiv_corr <- cb_model_para$jk_equiv_corr + jk_equiv_loglik <- cb_model_para$jk_equiv_loglik + func_compare <- cb_model_para$func_compare # - update only Ys which is not stop pos.update <- which(cb_model_para$update_y == 1) @@ -328,12 +308,8 @@ boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data, boost_check_update_jk_focal <- function(cb_model, cb_model_para, cb_data, - prioritize_jkstar = TRUE, - jk_equiv_corr = 0.8, ##### more than 2 traits - jk_equiv_loglik = 1, ## more than 2 traits - func_compare = "min_max", ##### more than 3 traits - coloc_thresh = 0.1, focal_outcome_idx = 1) { + ############# Output ################# # we will obtain and update the cb_model_para$update_status and cb_model_para$real_update_jk ###################################### @@ -345,6 +321,12 @@ boost_check_update_jk_focal <- function(cb_model, cb_model_para, cb_data, if (is.null(cb_model_para$coloc_thresh)) { cb_model_para$coloc_thresh <- (1 - coloc_thresh) * max(sapply(1:length(cb_model), function(i) max(cb_model[[i]]$change_loglike))) } + # - initial parameter + prioritize_jkstar <- cb_model_para$prioritize_jkstar + jk_equiv_corr <- cb_model_para$jk_equiv_corr + jk_equiv_loglik <- cb_model_para$jk_equiv_loglik + func_compare <- cb_model_para$func_compare + coloc_thresh <- cb_model_para$coloc_thresh # - update only Ys which is not stop pos.update <- which(cb_model_para$update_y == 1) diff --git a/R/colocboost_init.R b/R/colocboost_init.R index 1223ecd..f44da5d 100644 --- a/R/colocboost_init.R +++ b/R/colocboost_init.R @@ -30,7 +30,7 @@ colocboost_inits <- function() { colocboost_init_data <- function(X, Y, dict_YX, Z, LD, N_sumstat, dict_sumstatLD, Var_y, SeBhat, - keep_variables = NULL, + keep_variables, focal_outcome_idx = NULL, focal_outcome_variables = TRUE, overlap_variables = FALSE, @@ -213,12 +213,20 @@ colocboost_init_model <- function(cb_data, #' @importFrom utils tail colocboost_init_para <- function(cb_data, cb_model, tau = 0.01, func_simplex = "z2z", - lambda = 0.5, lambda_focal_outcome = 1, + lambda = 0.5, + lambda_focal_outcome = 1, + learning_rate_decay = 1, multi_test_thresh = 1, func_multi_test = "lfdr", LD_free = FALSE, outcome_names = NULL, - focal_outcome_idx = NULL) { + focal_outcome_idx = NULL, + dynamic_learning_rate = TRUE, + prioritize_jkstar = TRUE, + jk_equiv_corr = 0.8, + jk_equiv_loglik = 1, + func_compare = "min_max", + coloc_thresh = 0.1) { ################# initialization ####################################### # - sample size N <- sapply(cb_data$data, function(dt) dt$N) @@ -251,7 +259,8 @@ colocboost_init_para <- function(cb_data, cb_model, tau = 0.01, } else { outcome_names <- paste0("Y", 1:L) } - + # - initial coloc_thresh + coloc_thresh <- (1 - coloc_thresh) * max(sapply(1:length(cb_model), function(i) max(cb_model[[i]]$change_loglike))) cb_model_para <- list( "L" = L, "P" = P, @@ -260,7 +269,14 @@ colocboost_init_para <- function(cb_data, cb_model, tau = 0.01, "func_simplex" = func_simplex, "lambda" = lambda, "lambda_focal_outcome" = lambda_focal_outcome, + "learning_rate_decay" = learning_rate_decay, "profile_loglike" = profile_loglike, + "dynamic_learning_rate" = dynamic_learning_rate, + "prioritize_jkstar" = prioritize_jkstar, + "jk_equiv_corr" = jk_equiv_corr, + "jk_equiv_loglik" = jk_equiv_loglik, + "func_compare" = func_compare, + "coloc_thresh" = coloc_thresh, "update_status" = c(), "jk" = c(), "update_y" = update_y, diff --git a/R/colocboost_one_causal.R b/R/colocboost_one_causal.R index 031b67b..e680fc9 100644 --- a/R/colocboost_one_causal.R +++ b/R/colocboost_one_causal.R @@ -13,37 +13,13 @@ #' @rdname colocboost_one_causal #' @keywords cb_one_causal #' @noRd -colocboost_one_causal <- function(cb_model, cb_model_para, cb_data, - jk_equiv_corr = 0.8, - jk_equiv_loglik = 1, - tau = 0.01, - learning_rate_decay = 1, - func_simplex = "z2z", - lambda = 0.5, - lambda_focal_outcome = 1, - LD_free = FALSE) { - if (jk_equiv_corr != 0) { - cb_obj <- colocboost_one_iteration(cb_model, cb_model_para, cb_data, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - tau = tau, - learning_rate_decay = learning_rate_decay, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome, - LD_free = LD_free - ) +colocboost_one_causal <- function(cb_model, cb_model_para, cb_data) { + + if (cb_model_para$jk_equiv_corr != 0) { + cb_obj <- colocboost_one_iteration(cb_model, cb_model_para, cb_data) + } else { - cb_obj <- colocboost_diagLD(cb_model, cb_model_para, cb_data, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - tau = tau, - learning_rate_decay = learning_rate_decay, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome, - LD_free = LD_free - ) + cb_obj <- colocboost_diagLD(cb_model, cb_model_para, cb_data) } return(cb_obj) } @@ -64,21 +40,15 @@ colocboost_one_causal <- function(cb_model, cb_model_para, cb_data, #' #' @keywords cb_one_causal #' @noRd -colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data, - jk_equiv_corr = 0.8, - jk_equiv_loglik = 1, - tau = 0.01, - learning_rate_decay = 1, - func_simplex = "z2z", - lambda = 0.5, - lambda_focal_outcome = 1, - LD_free = FALSE) { +colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data) { + if (sum(cb_model_para$update_y == 1) != 0) { ######## - some traits updated # - step 1: check update clusters - real_update <- boost_check_update_jk_one_causal(cb_model, cb_model_para, cb_data, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik + real_update <- boost_check_update_jk_one_causal( + cb_model, + cb_model_para, + cb_data ) # - step 2: boost update @@ -89,13 +59,10 @@ colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data, cb_model_para$update_status <- cbind(cb_model_para$update_status, as.matrix(real_update[[i_update]]$update_status)) cb_model_para$real_update_jk <- rbind(cb_model_para$real_update_jk, real_update[[i_update]]$real_update_jk) # - update cb_model - cb_model <- colocboost_update(cb_model, cb_model_para, cb_data, - tau = tau, - learning_rate_decay = learning_rate_decay, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome, - LD_free = LD_free + cb_model <- colocboost_update( + cb_model, + cb_model_para, + cb_data ) } } @@ -130,10 +97,8 @@ colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data, #' #' @keywords cb_one_causal #' @noRd -boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data, - prioritize_jkstar = TRUE, - jk_equiv_corr = 0.8, - jk_equiv_loglik = 1) { +boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data) { + pos.update <- which(cb_model_para$update_y == 1) update_jk <- rep(NA, cb_model_para$L + 1) # - check pairwise equivalent of jk @@ -169,8 +134,8 @@ boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data, model_update = model_update, cb_data = cb_data, X_dict = X_dict, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik + jk_equiv_corr = cb_model_para$jk_equiv_corr, + jk_equiv_loglik = cb_model_para$jk_equiv_loglik ) # define category with same jk temp <- sapply(1:nrow(change_each_pair), function(x) { @@ -214,15 +179,8 @@ boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data, #' #' @keywords cb_one_causal #' @noRd -colocboost_diagLD <- function(cb_model, cb_model_para, cb_data, - jk_equiv_corr = 0, - jk_equiv_loglik = 0.1, - tau = 0.01, - learning_rate_decay = 1, - func_simplex = "z2z", - lambda = 0.5, - lambda_focal_outcome = 1, - LD_free = FALSE) { +colocboost_diagLD <- function(cb_model, cb_model_para, cb_data) { + if (sum(cb_model_para$update_y == 1) == 1) { pos.update <- which(cb_model_para$update_y == 1) update_jk <- rep(NA, cb_model_para$L + 1) @@ -242,9 +200,10 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data, cb_model_para$jk <- rbind(cb_model_para$jk, update_jk) cb_model_para$update_status <- cbind(cb_model_para$update_status, as.matrix(update_status)) cb_model_para$real_update_jk <- rbind(cb_model_para$real_update_jk, real_update_jk) - cb_model <- colocboost_update(cb_model, cb_model_para, cb_data, - tau = tau, learning_rate_decay = learning_rate_decay, func_simplex = func_simplex, - lambda = lambda, lambda_focal_outcome = lambda_focal_outcome, LD_free = LD_free + cb_model <- colocboost_update( + cb_model, + cb_model_para, + cb_data ) } @@ -273,9 +232,10 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data, "real_update_jk" = real_update_jk ) # - update cb_model - cb_model_tmp <- colocboost_update(cb_model_tmp, cb_model_para, cb_data, - tau = tau, learning_rate_decay = learning_rate_decay, func_simplex = func_simplex, - lambda = lambda, lambda_focal_outcome = lambda_focal_outcome, LD_free = LD_free + cb_model_tmp <- colocboost_update( + cb_model_tmp, + cb_model_para, + cb_data ) weights <- rbind(weights, cb_model_tmp[[iy]]$weights_path) } @@ -286,8 +246,8 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data, model_update = model_update, cb_data = cb_data, X_dict = X_dict, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik + jk_equiv_corr = cb_model_para$jk_equiv_corr, + jk_equiv_loglik = cb_model_para$jk_equiv_loglik ) change_each_pair <- change_each_pair * overlap_pair # define category with same jk @@ -327,13 +287,10 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data, cb_model_para$update_status <- cbind(cb_model_para$update_status, as.matrix(real_update[[i_update]]$update_status)) cb_model_para$real_update_jk <- rbind(cb_model_para$real_update_jk, real_update[[i_update]]$real_update_jk) # - update cb_model - cb_model <- colocboost_update(cb_model, cb_model_para, cb_data, - tau = tau, - learning_rate_decay = learning_rate_decay, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome, - LD_free = LD_free + cb_model <- colocboost_update( + cb_model, + cb_model_para, + cb_data ) } } diff --git a/R/colocboost_update.R b/R/colocboost_update.R index 46cfdfc..9109b66 100644 --- a/R/colocboost_update.R +++ b/R/colocboost_update.R @@ -6,17 +6,12 @@ #' @importFrom utils head tail #' @return colocboost object after gradient boosting update #' @noRd -colocboost_update <- function(cb_model, cb_model_para, cb_data, - tau = 0.01, - learning_rate_decay = 1, - func_simplex = "z2z", - lambda = 0.5, - lambda_focal_outcome = 1, - LD_free = FALSE, - dynamic_learning_rate = TRUE) { +colocboost_update <- function(cb_model, cb_model_para, cb_data) { + # - clear which outcome need to be updated at which jk pos.update <- which(cb_model_para$update_temp$update_status != 0) focal_outcome_idx <- cb_model_para$focal_outcome_idx + tau = cb_model_para$tau for (i in pos.update) { update_jk <- cb_model_para$update_temp$real_update_jk[i] @@ -45,14 +40,17 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data, # - calculate delta if (is.null(focal_outcome_idx)) { - lambda_outcome <- lambda + lambda_outcome <- cb_model_para$lambda } else { - lambda_outcome <- ifelse(i == focal_outcome_idx, lambda_focal_outcome, lambda) + lambda_outcome <- ifelse(i == focal_outcome_idx, + cb_model_para$lambda_focal_outcome, + cb_model_para$lambda) } delta <- boost_KL_delta( z = cb_model[[i]]$z, ld_feature = ld_feature, adj_dep = adj_dep, - func_simplex = func_simplex, lambda = lambda_outcome + func_simplex = cb_model_para$func_simplex, + lambda = lambda_outcome ) x_tmp <- cb_data$data[[X_dict]]$X @@ -62,7 +60,7 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data, } else { cb_model[[i]]$res / scaling_factor } - obj_ld <- if (LD_free) ld_feature else rep(1, length(ld_feature)) + obj_ld <- if (cb_model_para$LD_free) ld_feature else rep(1, length(ld_feature)) if (length(cb_data$data[[i]]$variable_miss) != 0) { obj_ld[cb_data$data[[i]]$variable_miss] <- 0 } @@ -83,9 +81,9 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data, ########## BEGIN: MAIN UPDATE ###################### # - Gradient ascent on beta beta_grad <- weights * sign(cb_model[[i]]$correlation) - if (dynamic_learning_rate) { + if (cb_model_para$dynamic_learning_rate) { if (tail(cb_model[[i]]$obj_path, n = 1) > 0.5) { - step1 <- max(0.5 * (1 / (1 + learning_rate_decay * (length(cb_model[[i]]$obj_path) - 1))), cb_model[[i]]$learning_rate_init) + step1 <- max(0.5 * (1 / (1 + cb_model_para$learning_rate_decay * (length(cb_model[[i]]$obj_path) - 1))), cb_model[[i]]$learning_rate_init) } else { step1 <- cb_model[[i]]$learning_rate_init } @@ -218,14 +216,11 @@ boost_check_stop <- function(cb_model, cb_model_para, pos_stop, } -boost_obj_last <- function(cb_data, cb_model, cb_model_para, - tau = 0.01, - func_simplex = "z2z", - lambda = 0.5, - lambda_focal_outcome = 1, - LD_free = TRUE) { +boost_obj_last <- function(cb_data, cb_model, cb_model_para) { + pos.stop <- cb_model_para$true_stop focal_outcome_idx <- cb_model_para$focal_outcome_idx + tau <- cb_model_para$tau for (i in pos.stop) { # - check which jk update @@ -250,14 +245,17 @@ boost_obj_last <- function(cb_data, cb_model, cb_model_para, # - calculate delta if (is.null(focal_outcome_idx)) { - lambda_outcome <- lambda + lambda_outcome <- cb_model_para$lambda } else { - lambda_outcome <- ifelse(i == focal_outcome_idx, lambda_focal_outcome, lambda) + lambda_outcome <- ifelse(i == focal_outcome_idx, + cb_model_para$lambda_focal_outcome, + cb_model_para$lambda) } delta <- boost_KL_delta( z = cb_model[[i]]$z, ld_feature = ld_feature, adj_dep = adj_dep, - func_simplex = func_simplex, lambda = lambda_outcome + func_simplex = cb_model_para$func_simplex, + lambda = lambda_outcome ) x_tmp <- cb_data$data[[X_dict]]$X @@ -268,7 +266,7 @@ boost_obj_last <- function(cb_data, cb_model, cb_model_para, cb_model[[i]]$res / scaling_factor } - obj_ld <- if (LD_free) ld_feature else rep(1, length(ld_feature)) + obj_ld <- if (cb_model_para$LD_free) ld_feature else rep(1, length(ld_feature)) if (length(cb_data$data[[i]]$variable_miss) != 0) { obj_ld[cb_data$data[[i]]$variable_miss] <- 0 } diff --git a/R/colocboost_workhorse.R b/R/colocboost_workhorse.R index 5d20fe0..e81928b 100644 --- a/R/colocboost_workhorse.R +++ b/R/colocboost_workhorse.R @@ -18,7 +18,7 @@ colocboost_workhorse <- function(cb_data, prioritize_jkstar = TRUE, learning_rate_init = 0.01, learning_rate_decay = 1, - func_simplex = "z2z", + func_simplex = "LD_z2z", lambda = 0.5, lambda_focal_outcome = 1, jk_equiv_corr = 0.8, @@ -56,15 +56,20 @@ colocboost_workhorse <- function(cb_data, func_simplex = func_simplex, lambda = lambda, lambda_focal_outcome = lambda_focal_outcome, + learning_rate_decay = learning_rate_decay, multi_test_thresh = multi_test_thresh, func_multi_test = func_multi_test, LD_free = LD_free, outcome_names = outcome_names, - focal_outcome_idx = focal_outcome_idx + focal_outcome_idx = focal_outcome_idx, + dynamic_learning_rate = dynamic_learning_rate, + prioritize_jkstar = prioritize_jkstar, + jk_equiv_corr = jk_equiv_corr, + jk_equiv_loglik = jk_equiv_loglik, + func_compare = func_compare, + coloc_thresh = coloc_thresh ) - - - + if (is.null(M)) { M <- cb_model_para$L * 200 } @@ -102,16 +107,7 @@ colocboost_workhorse <- function(cb_data, if (M == 1) { # single effect with or without LD matrix message("Running colocboost with assumption of one causal per outcome!") - cb_obj <- colocboost_one_causal(cb_model, cb_model_para, cb_data, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - tau = tau, - learning_rate_decay = learning_rate_decay, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome, - LD_free = LD_free - ) + cb_obj <- colocboost_one_causal(cb_model, cb_model_para, cb_data) cb_obj$cb_model_para$coveraged <- "one_causal" } else { # - add more iterations for more outcomes @@ -120,24 +116,10 @@ colocboost_workhorse <- function(cb_data, break } else { # step 1: check which outcomes need to be updated at which variant - cb_model_para <- colocboost_check_update_jk(cb_model, cb_model_para, cb_data, - prioritize_jkstar = prioritize_jkstar, - jk_equiv_corr = jk_equiv_corr, - jk_equiv_loglik = jk_equiv_loglik, - func_compare = func_compare, - coloc_thresh = coloc_thresh - ) + cb_model_para <- colocboost_check_update_jk(cb_model, cb_model_para, cb_data) # step 2: gradient boosting for the updated outcomes - cb_model <- colocboost_update(cb_model, cb_model_para, cb_data, - tau = tau, - learning_rate_decay = learning_rate_decay, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome, - LD_free = LD_free, - dynamic_learning_rate = dynamic_learning_rate - ) + cb_model <- colocboost_update(cb_model, cb_model_para, cb_data) # step 3: check stop for the updated ones # # - update cb_model and cb_model_parameter @@ -208,12 +190,7 @@ colocboost_workhorse <- function(cb_data, if (!is.null(cb_model_para$true_stop)) { ####### --------------------------------------------- # calculate objective function of Y for the last iteration. - cb_model <- boost_obj_last(cb_data, cb_model, cb_model_para, - tau = tau, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome - ) + cb_model <- boost_obj_last(cb_data, cb_model, cb_model_para) if (!is.null(focal_outcome_idx)) { if (focal_outcome_idx %in% cb_model_para$true_stop) { message(paste( @@ -261,12 +238,7 @@ colocboost_workhorse <- function(cb_data, cb_model_para$true_stop <- which(cb_model_para$update_y == 1) ####### --------------------------------------------- # calculate objective function of Y for the last iteration. - cb_model <- boost_obj_last(cb_data, cb_model, cb_model_para, - tau = tau, - func_simplex = func_simplex, - lambda = lambda, - lambda_focal_outcome = lambda_focal_outcome - ) + cb_model <- boost_obj_last(cb_data, cb_model, cb_model_para) warning(paste("COLOC-BOOST updates did not converge in", M, "iterations; checkpoint at last iteration")) cb_model_para$coveraged <- FALSE } diff --git a/tests/testthat/test_corner_cases.R b/tests/testthat/test_corner_cases.R index 2e8065b..037198d 100644 --- a/tests/testthat/test_corner_cases.R +++ b/tests/testthat/test_corner_cases.R @@ -118,14 +118,18 @@ test_that("colocboost handles different sample sizes", { # Test with proper row names to handle different sample sizes X <- test_data$X rownames(X) <- paste0("sample", 1:nrow(X)) - Y1 <- test_data$Y[[1]] - Y2 <- test_data$Y[[2]] + Y1 <- as.matrix(test_data$Y[[1]]) + Y2 <- as.matrix(test_data$Y[[2]]) rownames(Y1) <- paste0("sample", 1:length(Y1)) rownames(Y2) <- paste0("sample", 1:length(Y2)) + Y <- list(Y1, Y2) # Skip test if function fails in unexpected ways - # (This test may require more complex setup than we can do here) - skip("Requires specialized setup for different sample sizes") + # (This test may require more complex setup than we can do here) + # skip("Requires specialized setup for different sample sizes") + # Can handle on Apr 17, 2025 (News) + # Run colocboost - should only warning + expect_warning(colocboost(X = X, Y = Y)) }) # Test colocboost with different variant sets @@ -135,9 +139,6 @@ test_that("colocboost handles different variant sets", { # Generate data with different variant sets test_data <- generate_edge_case_data("different_variants") - # Need to create dict_YX for this case - dict_YX <- matrix(c(1, 2, 1, 2), ncol=2) - # Run colocboost expect_warning( result <- colocboost( @@ -145,8 +146,7 @@ test_that("colocboost handles different variant sets", { Y = test_data$Y, dict_YX = dict_YX, M = 5 # Small number of iterations for testing - ), - NA + ) ) # Test that we get a colocboost object diff --git a/tests/testthat/test_inference.R b/tests/testthat/test_inference.R new file mode 100644 index 0000000..8c92f1d --- /dev/null +++ b/tests/testthat/test_inference.R @@ -0,0 +1,101 @@ +library(testthat) + +# Utility function to generate a simple colocboost results +generate_test_result <- function(n = 100, p = 20, L = 2, seed = 42) { + set.seed(seed) + + # Generate X with LD structure + sigma <- matrix(0, p, p) + for (i in 1:p) { + for (j in 1:p) { + sigma[i, j] <- 0.9^abs(i - j) + } + } + X <- MASS::mvrnorm(n, rep(0, p), sigma) + colnames(X) <- paste0("SNP", 1:p) + + # Generate true effects - create a shared causal variant + true_beta <- matrix(0, p, L) + true_beta[5, 1] <- 0.5 # SNP5 affects trait 1 + true_beta[5, 2] <- 0.4 # SNP5 also affects trait 2 (colocalized) + true_beta[10, 2] <- 0.3 # SNP10 only affects trait 2 + + # Generate Y with some noise + Y <- matrix(0, n, L) + for (l in 1:L) { + Y[, l] <- X %*% true_beta[, l] + rnorm(n, 0, 1) + } + + # Convert Y to list + Y_list <- list(Y[,1], Y[,2]) + X_list <- list(X, X) + + # Run colocboost with minimal parameters to get a model object + suppressWarnings({ + result <- colocboost( + X = X_list, + Y = Y_list, + M = 5, # Small number of iterations for faster testing + output_level = 3 # Include full model details + ) + }) + result +} + + +# Test colocboost_plot function +test_that("colocboost_plot handles different plot options", { + skip_on_cran() + + # Generate a test colocboost results + cb_res <- generate_test_result() + + # Basic plot call + expect_error(suppressWarnings(colocboost_plot(cb_res)), NA) + + # Test with different y-axis values + expect_error(suppressWarnings(colocboost_plot(cb_res, y = "z_original")), NA) + + # Test with different outcome_idx + expect_error(suppressWarnings(colocboost_plot(cb_res, outcome_idx = 1)), NA) +}) + +# Test get_cos_summary function +test_that("get_cos_summary handles different parameters", { + skip_on_cran() + + # Generate a test colocboost results + cb_res <- generate_test_result() + + # Basic summary call + expect_error(get_cos_summary(cb_res), NA) + + # With custom outcome names + expect_error(get_cos_summary(cb_res, outcome_names = c("Trait1", "Trait2")), NA) + + # With gene name + summary_with_gene <- get_cos_summary(cb_res, region_name = "TestGene") + + # If summary is not NULL, check for region_name column + if (!is.null(summary_with_gene)) { + expect_true("region_name" %in% colnames(summary_with_gene)) + expect_equal(summary_with_gene$region_name[1], "TestGene") + } +}) + +# Test for get_strong_colocalization +test_that("get_strong_colocalization filters results correctly", { + skip_on_cran() + + # Generate a test colocboost results + cb_res <- generate_test_result() + + # Basic call + expect_error(get_strong_colocalization(cb_res), NA) + + # With stricter thresholds + expect_error(get_strong_colocalization(cb_res, cos_npc_cutoff = 0.8), NA) + + # With p-value threshold + expect_error(get_strong_colocalization(cb_res, pvalue_cutoff = 0.05), NA) +}) \ No newline at end of file diff --git a/tests/testthat/test_model.R b/tests/testthat/test_model.R index 944424e..68420ab 100644 --- a/tests/testthat/test_model.R +++ b/tests/testthat/test_model.R @@ -39,71 +39,24 @@ generate_test_model <- function(n = 100, p = 20, L = 2, seed = 42) { output_level = 3 # Include full model details )$diagnostic_details }) + result$cb_model_para$update_y <- c(1:result$cb_model_para$L) + Y_list <- lapply(Y_list, as.matrix) result$cb_data <- colocboost_init_data( - X = X_list, - Y = Y_list, - dict_YX = NULL, - Z = NULL, - LD = NULL, - N_sumstat = NULL, - dict_sumstatLD = NULL, - Var_y = NULL, - SeBhat = NULL - ) + X = X_list, + Y = Y_list, + dict_YX = c(1,2), + Z = NULL, + LD = NULL, + N_sumstat = NULL, + dict_sumstatLD = NULL, + Var_y = NULL, + SeBhat = NULL, + keep_variables = lapply(X_list, colnames) + ) + class(result) <- "colocboost" result } -# Test for colocboost_check_update_jk -test_that("colocboost_check_update_jk handles update selection", { - skip_on_cran() - - # Generate a test model - cb_obj <- generate_test_model() - - # Check that function can be called without error - # Note: This is testing a function that's normally internal - # and would typically be covered by testing the main function - if (exists("colocboost_check_update_jk")) { - expect_error({ - result <- colocboost_check_update_jk( - cb_obj$cb_model, - cb_obj$cb_model_para, - cb_obj$cb_data - ) - }, NA) - } else { - skip("colocboost_check_update_jk not directly accessible") - } -}) - -# Test for colocboost_update -test_that("colocboost_update updates model parameters", { - skip_on_cran() - - # Generate a test model - cb_obj <- generate_test_model() - - # Access update function if it's exported - if (exists("colocboost_update")) { - # Create a temporary update status to test with - cb_obj$cb_model_para$update_temp <- list( - update_status = rep(1, cb_obj$cb_model_para$L), - real_update_jk = rep(1, cb_obj$cb_model_para$L) - ) - - # Test function - expect_error({ - updated_model <- colocboost_update( - cb_obj$cb_model, - cb_obj$cb_model_para, - cb_obj$cb_data - ) - }, NA) - } else { - skip("colocboost_update not directly accessible") - } -}) - # Test for colocboost_init_data test_that("colocboost_init_data correctly initializes data", { skip_on_cran() @@ -121,14 +74,15 @@ test_that("colocboost_init_data correctly initializes data", { expect_error({ cb_data <- colocboost_init_data( X = list(X), - Y = list(Y[,1], Y[,2]), - dict_YX = matrix(c(1:2, rep(1,2)), ncol=2), + Y = list(Y[,1,drop=F], Y[,2,drop=F]), + dict_YX = c(1,1), Z = NULL, LD = NULL, N_sumstat = NULL, dict_sumstatLD = NULL, Var_y = NULL, - SeBhat = NULL + SeBhat = NULL, + keep_variables = list(colnames(X)), ) }, NA) } else { @@ -160,88 +114,19 @@ test_that("colocboost_assemble processes model results", { test_that("colocboost_workhorse performs boosting iterations", { skip_on_cran() - # Generate test data - set.seed(42) - n <- 50 - p <- 10 - X <- matrix(rnorm(n*p), n, p) - colnames(X) <- paste0("SNP", 1:p) - Y <- matrix(rnorm(n*2), n, 2) - Y_list <- list(Y[,1], Y[,2]) - X_list <- list(X, X) - - # Initialize CB objects - suppressWarnings({ - # First get the data object by running colocboost with M=0 - temp <- colocboost(X = X_list, Y = Y_list, M = 0) - - # If the workhorse function is exported - if (exists("colocboost_workhorse")) { - expect_error({ - result <- colocboost_workhorse( - temp$cb_data, - M = 5 # Small number for testing - ) - }, NA) - } else { - skip("colocboost_workhorse not directly accessible") - } - }) -}) - -# Test colocboost_plot function -test_that("colocboost_plot handles different plot options", { - skip_on_cran() - - # Generate a test model - cb_obj <- generate_test_model() - - # Basic plot call - expect_error(suppressWarnings(colocboost_plot(cb_obj)), NA) - - # Test with different y-axis values - expect_error(suppressWarnings(colocboost_plot(cb_obj, y = "z_original")), NA) - - # Test with different outcome_idx - expect_error(suppressWarnings(colocboost_plot(cb_obj, outcome_idx = 1)), NA) -}) - -# Test get_cos_summary function -test_that("get_cos_summary handles different parameters", { - skip_on_cran() - # Generate a test model cb_obj <- generate_test_model() - # Basic summary call - expect_error(get_cos_summary(cb_obj), NA) - - # With custom outcome names - expect_error(get_cos_summary(cb_obj, outcome_names = c("Trait1", "Trait2")), NA) - - # With gene name - summary_with_gene <- get_cos_summary(cb_obj, gene_name = "TestGene") - - # If summary is not NULL, check for gene_name column - if (!is.null(summary_with_gene)) { - expect_true("gene_name" %in% colnames(summary_with_gene)) - expect_equal(summary_with_gene$gene_name[1], "TestGene") + # If the workhorse function is exported + if (exists("colocboost_workhorse")) { + expect_error({ + result <- colocboost_workhorse( + cb_obj$cb_data, + M = 5 # Small number for testing + ) + }, NA) + } else { + skip("colocboost_workhorse not directly accessible") } -}) -# Test for get_strong_colocalization -test_that("get_strong_colocalization filters results correctly", { - skip_on_cran() - - # Generate a test model - cb_obj <- generate_test_model() - - # Basic call - expect_error(get_strong_colocalization(cb_obj), NA) - - # With stricter thresholds - expect_error(get_strong_colocalization(cb_obj, cos_npc_cutoff = 0.8), NA) - - # With p-value threshold - expect_error(get_strong_colocalization(cb_obj, pvalue_cutoff = 0.05), NA) -}) \ No newline at end of file +})