Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 22 additions & 40 deletions R/colocboost_check_update_jk.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,22 @@
#'
#' @return update_status and real_update_jk for each trait
#' @noRd
colocboost_check_update_jk <- function(cb_model, cb_model_para, cb_data,
prioritize_jkstar = TRUE,
jk_equiv_corr = 0.8, ##### more than 2 traits
jk_equiv_loglik = 1, ## more than 2 traits
func_compare = "min_max", ##### more than 3 traits
coloc_thresh = 0.1) {
colocboost_check_update_jk <- function(cb_model, cb_model_para, cb_data) {

pos.update <- which(cb_model_para$update_y == 1)
focal_outcome_idx <- cb_model_para$focal_outcome_idx
if (is.null(focal_outcome_idx)) {
cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data,
prioritize_jkstar = prioritize_jkstar,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik,
func_compare = func_compare,
coloc_thresh = coloc_thresh
)
cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data)
} else {
if (focal_outcome_idx %in% pos.update) {
cb_model_para <- boost_check_update_jk_focal(cb_model, cb_model_para, cb_data,
prioritize_jkstar = prioritize_jkstar,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik,
func_compare = func_compare,
coloc_thresh = coloc_thresh,
cb_model_para <- boost_check_update_jk_focal(
cb_model,
cb_model_para,
cb_data,
focal_outcome_idx = focal_outcome_idx
)
} else {
cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data,
prioritize_jkstar = prioritize_jkstar,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik,
func_compare = func_compare,
coloc_thresh = coloc_thresh
)
cb_model_para <- boost_check_update_jk_nofocal(cb_model, cb_model_para, cb_data)
}
}
return(cb_model_para)
Expand All @@ -47,12 +29,8 @@ colocboost_check_update_jk <- function(cb_model, cb_model_para, cb_data,


#' @importFrom stats median
boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data,
prioritize_jkstar = TRUE,
jk_equiv_corr = 0.8, ##### more than 2 traits
jk_equiv_loglik = 1, ## more than 2 traits
func_compare = "min_max", ##### more than 3 traits
coloc_thresh = 0.1) {
boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data) {

############# Output #################
# we will obtain and update the cb_model_para$update_status and cb_model_para$real_update_jk
######################################
Expand All @@ -61,9 +39,11 @@ boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data,
update_status <- rep(0, cb_model_para$L)
update_jk <- rep(NA, cb_model_para$L + 1)
real_update_jk <- rep(NA, cb_model_para$L)
if (is.null(cb_model_para$coloc_thresh)) {
cb_model_para$coloc_thresh <- (1 - coloc_thresh) * max(sapply(1:length(cb_model), function(i) max(cb_model[[i]]$change_loglike)))
}
# - initial parameter
prioritize_jkstar <- cb_model_para$prioritize_jkstar
jk_equiv_corr <- cb_model_para$jk_equiv_corr
jk_equiv_loglik <- cb_model_para$jk_equiv_loglik
func_compare <- cb_model_para$func_compare

# - update only Ys which is not stop
pos.update <- which(cb_model_para$update_y == 1)
Expand Down Expand Up @@ -328,12 +308,8 @@ boost_check_update_jk_nofocal <- function(cb_model, cb_model_para, cb_data,


boost_check_update_jk_focal <- function(cb_model, cb_model_para, cb_data,
prioritize_jkstar = TRUE,
jk_equiv_corr = 0.8, ##### more than 2 traits
jk_equiv_loglik = 1, ## more than 2 traits
func_compare = "min_max", ##### more than 3 traits
coloc_thresh = 0.1,
focal_outcome_idx = 1) {

############# Output #################
# we will obtain and update the cb_model_para$update_status and cb_model_para$real_update_jk
######################################
Expand All @@ -345,6 +321,12 @@ boost_check_update_jk_focal <- function(cb_model, cb_model_para, cb_data,
if (is.null(cb_model_para$coloc_thresh)) {
cb_model_para$coloc_thresh <- (1 - coloc_thresh) * max(sapply(1:length(cb_model), function(i) max(cb_model[[i]]$change_loglike)))
}
# - initial parameter
prioritize_jkstar <- cb_model_para$prioritize_jkstar
jk_equiv_corr <- cb_model_para$jk_equiv_corr
jk_equiv_loglik <- cb_model_para$jk_equiv_loglik
func_compare <- cb_model_para$func_compare
coloc_thresh <- cb_model_para$coloc_thresh

# - update only Ys which is not stop
pos.update <- which(cb_model_para$update_y == 1)
Expand Down
24 changes: 20 additions & 4 deletions R/colocboost_init.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ colocboost_inits <- function() {
colocboost_init_data <- function(X, Y, dict_YX,
Z, LD, N_sumstat, dict_sumstatLD,
Var_y, SeBhat,
keep_variables = NULL,
keep_variables,
focal_outcome_idx = NULL,
focal_outcome_variables = TRUE,
overlap_variables = FALSE,
Expand Down Expand Up @@ -213,12 +213,20 @@ colocboost_init_model <- function(cb_data,
#' @importFrom utils tail
colocboost_init_para <- function(cb_data, cb_model, tau = 0.01,
func_simplex = "z2z",
lambda = 0.5, lambda_focal_outcome = 1,
lambda = 0.5,
lambda_focal_outcome = 1,
learning_rate_decay = 1,
multi_test_thresh = 1,
func_multi_test = "lfdr",
LD_free = FALSE,
outcome_names = NULL,
focal_outcome_idx = NULL) {
focal_outcome_idx = NULL,
dynamic_learning_rate = TRUE,
prioritize_jkstar = TRUE,
jk_equiv_corr = 0.8,
jk_equiv_loglik = 1,
func_compare = "min_max",
coloc_thresh = 0.1) {
################# initialization #######################################
# - sample size
N <- sapply(cb_data$data, function(dt) dt$N)
Expand Down Expand Up @@ -251,7 +259,8 @@ colocboost_init_para <- function(cb_data, cb_model, tau = 0.01,
} else {
outcome_names <- paste0("Y", 1:L)
}

# - initial coloc_thresh
coloc_thresh <- (1 - coloc_thresh) * max(sapply(1:length(cb_model), function(i) max(cb_model[[i]]$change_loglike)))
cb_model_para <- list(
"L" = L,
"P" = P,
Expand All @@ -260,7 +269,14 @@ colocboost_init_para <- function(cb_data, cb_model, tau = 0.01,
"func_simplex" = func_simplex,
"lambda" = lambda,
"lambda_focal_outcome" = lambda_focal_outcome,
"learning_rate_decay" = learning_rate_decay,
"profile_loglike" = profile_loglike,
"dynamic_learning_rate" = dynamic_learning_rate,
"prioritize_jkstar" = prioritize_jkstar,
"jk_equiv_corr" = jk_equiv_corr,
"jk_equiv_loglik" = jk_equiv_loglik,
"func_compare" = func_compare,
"coloc_thresh" = coloc_thresh,
"update_status" = c(),
"jk" = c(),
"update_y" = update_y,
Expand Down
115 changes: 36 additions & 79 deletions R/colocboost_one_causal.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,13 @@
#' @rdname colocboost_one_causal
#' @keywords cb_one_causal
#' @noRd
colocboost_one_causal <- function(cb_model, cb_model_para, cb_data,
jk_equiv_corr = 0.8,
jk_equiv_loglik = 1,
tau = 0.01,
learning_rate_decay = 1,
func_simplex = "z2z",
lambda = 0.5,
lambda_focal_outcome = 1,
LD_free = FALSE) {
if (jk_equiv_corr != 0) {
cb_obj <- colocboost_one_iteration(cb_model, cb_model_para, cb_data,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik,
tau = tau,
learning_rate_decay = learning_rate_decay,
func_simplex = func_simplex,
lambda = lambda,
lambda_focal_outcome = lambda_focal_outcome,
LD_free = LD_free
)
colocboost_one_causal <- function(cb_model, cb_model_para, cb_data) {

if (cb_model_para$jk_equiv_corr != 0) {
cb_obj <- colocboost_one_iteration(cb_model, cb_model_para, cb_data)

} else {
cb_obj <- colocboost_diagLD(cb_model, cb_model_para, cb_data,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik,
tau = tau,
learning_rate_decay = learning_rate_decay,
func_simplex = func_simplex,
lambda = lambda,
lambda_focal_outcome = lambda_focal_outcome,
LD_free = LD_free
)
cb_obj <- colocboost_diagLD(cb_model, cb_model_para, cb_data)
}
return(cb_obj)
}
Expand All @@ -64,21 +40,15 @@ colocboost_one_causal <- function(cb_model, cb_model_para, cb_data,
#'
#' @keywords cb_one_causal
#' @noRd
colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data,
jk_equiv_corr = 0.8,
jk_equiv_loglik = 1,
tau = 0.01,
learning_rate_decay = 1,
func_simplex = "z2z",
lambda = 0.5,
lambda_focal_outcome = 1,
LD_free = FALSE) {
colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data) {

if (sum(cb_model_para$update_y == 1) != 0) {
######## - some traits updated
# - step 1: check update clusters
real_update <- boost_check_update_jk_one_causal(cb_model, cb_model_para, cb_data,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik
real_update <- boost_check_update_jk_one_causal(
cb_model,
cb_model_para,
cb_data
)

# - step 2: boost update
Expand All @@ -89,13 +59,10 @@ colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data,
cb_model_para$update_status <- cbind(cb_model_para$update_status, as.matrix(real_update[[i_update]]$update_status))
cb_model_para$real_update_jk <- rbind(cb_model_para$real_update_jk, real_update[[i_update]]$real_update_jk)
# - update cb_model
cb_model <- colocboost_update(cb_model, cb_model_para, cb_data,
tau = tau,
learning_rate_decay = learning_rate_decay,
func_simplex = func_simplex,
lambda = lambda,
lambda_focal_outcome = lambda_focal_outcome,
LD_free = LD_free
cb_model <- colocboost_update(
cb_model,
cb_model_para,
cb_data
)
}
}
Expand Down Expand Up @@ -130,10 +97,8 @@ colocboost_one_iteration <- function(cb_model, cb_model_para, cb_data,
#'
#' @keywords cb_one_causal
#' @noRd
boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data,
prioritize_jkstar = TRUE,
jk_equiv_corr = 0.8,
jk_equiv_loglik = 1) {
boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data) {

pos.update <- which(cb_model_para$update_y == 1)
update_jk <- rep(NA, cb_model_para$L + 1)
# - check pairwise equivalent of jk
Expand Down Expand Up @@ -169,8 +134,8 @@ boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data,
model_update = model_update,
cb_data = cb_data,
X_dict = X_dict,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik
jk_equiv_corr = cb_model_para$jk_equiv_corr,
jk_equiv_loglik = cb_model_para$jk_equiv_loglik
)
# define category with same jk
temp <- sapply(1:nrow(change_each_pair), function(x) {
Expand Down Expand Up @@ -214,15 +179,8 @@ boost_check_update_jk_one_causal <- function(cb_model, cb_model_para, cb_data,
#'
#' @keywords cb_one_causal
#' @noRd
colocboost_diagLD <- function(cb_model, cb_model_para, cb_data,
jk_equiv_corr = 0,
jk_equiv_loglik = 0.1,
tau = 0.01,
learning_rate_decay = 1,
func_simplex = "z2z",
lambda = 0.5,
lambda_focal_outcome = 1,
LD_free = FALSE) {
colocboost_diagLD <- function(cb_model, cb_model_para, cb_data) {

if (sum(cb_model_para$update_y == 1) == 1) {
pos.update <- which(cb_model_para$update_y == 1)
update_jk <- rep(NA, cb_model_para$L + 1)
Expand All @@ -242,9 +200,10 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data,
cb_model_para$jk <- rbind(cb_model_para$jk, update_jk)
cb_model_para$update_status <- cbind(cb_model_para$update_status, as.matrix(update_status))
cb_model_para$real_update_jk <- rbind(cb_model_para$real_update_jk, real_update_jk)
cb_model <- colocboost_update(cb_model, cb_model_para, cb_data,
tau = tau, learning_rate_decay = learning_rate_decay, func_simplex = func_simplex,
lambda = lambda, lambda_focal_outcome = lambda_focal_outcome, LD_free = LD_free
cb_model <- colocboost_update(
cb_model,
cb_model_para,
cb_data
)
}

Expand Down Expand Up @@ -273,9 +232,10 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data,
"real_update_jk" = real_update_jk
)
# - update cb_model
cb_model_tmp <- colocboost_update(cb_model_tmp, cb_model_para, cb_data,
tau = tau, learning_rate_decay = learning_rate_decay, func_simplex = func_simplex,
lambda = lambda, lambda_focal_outcome = lambda_focal_outcome, LD_free = LD_free
cb_model_tmp <- colocboost_update(
cb_model_tmp,
cb_model_para,
cb_data
)
weights <- rbind(weights, cb_model_tmp[[iy]]$weights_path)
}
Expand All @@ -286,8 +246,8 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data,
model_update = model_update,
cb_data = cb_data,
X_dict = X_dict,
jk_equiv_corr = jk_equiv_corr,
jk_equiv_loglik = jk_equiv_loglik
jk_equiv_corr = cb_model_para$jk_equiv_corr,
jk_equiv_loglik = cb_model_para$jk_equiv_loglik
)
change_each_pair <- change_each_pair * overlap_pair
# define category with same jk
Expand Down Expand Up @@ -327,13 +287,10 @@ colocboost_diagLD <- function(cb_model, cb_model_para, cb_data,
cb_model_para$update_status <- cbind(cb_model_para$update_status, as.matrix(real_update[[i_update]]$update_status))
cb_model_para$real_update_jk <- rbind(cb_model_para$real_update_jk, real_update[[i_update]]$real_update_jk)
# - update cb_model
cb_model <- colocboost_update(cb_model, cb_model_para, cb_data,
tau = tau,
learning_rate_decay = learning_rate_decay,
func_simplex = func_simplex,
lambda = lambda,
lambda_focal_outcome = lambda_focal_outcome,
LD_free = LD_free
cb_model <- colocboost_update(
cb_model,
cb_model_para,
cb_data
)
}
}
Expand Down
Loading