Skip to content

Commit 577d78a

Browse files
authored
fit outcome model method (#156)
* fitter and test * another test * add pkgdown entries
1 parent 176b436 commit 577d78a

11 files changed

+196
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ repos:
1818
- parglm
1919
- sandwich
2020
- duckdb
21+
- lmtest
2122
# codemeta must be above use-tidy-description when both are used
2223
# - id: codemeta-description-updated
2324
- id: use-tidy-description

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Imports:
2929
DBI,
3030
duckdb,
3131
formula.tools,
32+
lmtest,
3233
methods,
3334
mvtnorm,
3435
parglm,

R/generics.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,14 @@ setGeneric("read_expanded_data", function(object, period = NULL) standardGeneric
174174
#' @param object A `te_outcome_fitter` object
175175
#' @param data `data.frame` containing outcomes and covariates as defined in `formula`.
176176
#' @param formula `formula` describing the model.
177+
#' @param weights `numeric` vector of weights.
177178
#'
178179
#' @return An object of class `te_outcome_fitted`
179180
#' @export
180181
#' @keywords internal
181182
#' @examples
182-
setGeneric("fit_outcome_model", function(object, data, formula) standardGeneric("fit_outcome_model"))
183+
#' fit_outcome_model
184+
setGeneric("fit_outcome_model", function(object, data, formula, weights = NULL) standardGeneric("fit_outcome_model"))
183185

184186
#' Method for fitting weight models
185187
#'

R/te_model_fitter.R

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#' @include generics.R
2+
NULL
3+
14
#' Outcome Model Fitter Class
25
#'
36
#' This is a virtual class which other outcome model fitter classes should inherit from. Objects of these class exist to
@@ -65,3 +68,51 @@ setMethod(
6568
)
6669
}
6770
)
71+
72+
#' @rdname fit_outcome_model
73+
setMethod(
74+
f = "fit_outcome_model",
75+
signature = "te_stats_glm_logit",
76+
function(object, data, formula, weights) {
77+
data$weights <- if (is.null(weights)) rep(1, nrow(data)) else weights
78+
model <- glm(
79+
formula = formula,
80+
data = data,
81+
family = binomial("logit"),
82+
x = FALSE,
83+
y = FALSE,
84+
weights = weights
85+
)
86+
if (!is.na(object@save_path)) {
87+
if (!dir.exists(object@save_path)) dir.create(object@save_path, recursive = TRUE)
88+
file <- tempfile(pattern = "model_", tmpdir = object@save_path, fileext = ".rds")
89+
saveRDS(model, file = file)
90+
save_path <- data.frame(save = file)
91+
}
92+
93+
vcov <- sandwich::vcovCL(
94+
model,
95+
cluster = data[["id"]],
96+
type = NULL,
97+
sandwich = TRUE,
98+
fix = FALSE
99+
)
100+
101+
model_list <- list(
102+
model = model,
103+
vcov = vcov
104+
)
105+
106+
coef_obj <- lmtest::coeftest(model, vcov. = vcov, save = TRUE)
107+
summary_list <- list()
108+
summary_list[["tidy"]] <- broom::tidy(coef_obj, conf.int = TRUE)
109+
summary_list[["glance"]] <- broom::glance(coef_obj)
110+
if (!is.na(object@save_path)) summary_list[["save_path"]] <- save_path
111+
112+
new(
113+
"te_stats_glm_logit_outcome_fitted",
114+
model = model_list,
115+
summary = summary_list
116+
)
117+
}
118+
)

R/te_outcome_model.R

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,35 @@
1+
#' Fitted Outcome Model Object
2+
#'
3+
#' @slot model list containing fitted model objects.
4+
#' @slot summary list of data.frames. Tidy model summaries a la `broom()` and `glance()`
5+
setClass(
6+
"te_outcome_fitted",
7+
slots = c(
8+
model = "list",
9+
summary = "list"
10+
)
11+
)
12+
13+
#' Fitted Outcome Model Object
14+
#'
15+
#' @slot formula `formula` object for the model fitting
16+
#' @slot treatment_var character. The treatment variable
17+
#' @slot adjustment_vars character. Adjustment variables
18+
#' @slot model_fitter Model fitter object
19+
#' @slot fitted list. Saves the model objects
120
setClass("te_outcome_model",
221
slots = c(
322
formula = "formula",
423
treatment_var = "character",
524
adjustment_vars = "character",
6-
model_fitter = "te_model_fitter"
25+
model_fitter = "te_model_fitter",
26+
fitted = "te_outcome_fitted"
727
)
828
)
29+
30+
# te_stats_glm_logit_outcome_fitted -----
31+
32+
setClass(
33+
"te_stats_glm_logit_outcome_fitted",
34+
contains = "te_outcome_fitted"
35+
)

R/te_weights.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#' Fitted Weights Object
44
#'
5-
#' @slot specification list. The parameters specifying how the model should be fit
5+
#' @slot label string. A short description of the model
66
#' @slot summary list of data.frames. Tidy model summaries a la `broom()` and `glance()`
77
#' @slot fitted list. Saves the model objects or at least summaries if large
88
setClass("te_weights_fitted",

_pkgdown.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ reference:
4444
- te_data-class
4545
- te_datastore-class
4646
- te_model_fitter-class
47+
- te_outcome_fitted-class
48+
- te_outcome_model-class
4749
- title: Internal Methods
4850
- contents:
4951
- save_expanded_data

man/fit_outcome_model.Rd

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/te_outcome_fitted-class.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/te_outcome_model-class.Rd

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-te_model_fitter.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,62 @@ test_that("fit_weights_model works for stats_glm_logit", {
2727
saved_model <- readRDS(result@summary$save_path$path)
2828
expect_class(saved_model, "glm")
2929
})
30+
31+
32+
test_that("fit_outcome_model works for stats_glm_logit", {
33+
object <- stats_glm_logit(NA)
34+
result <- fit_outcome_model(
35+
object,
36+
data = vignette_switch_data,
37+
formula = outcome ~ assigned_treatment + followup_time + nvarC,
38+
weights = vignette_switch_data$weight
39+
)
40+
41+
expect_class(result, "te_outcome_fitted")
42+
expect_class(result, "te_stats_glm_logit_outcome_fitted")
43+
44+
expect_equal(
45+
result@summary[["tidy"]]$estimate,
46+
c(-3.23428770614918, 0.0429082229813285, -0.000849656189863374, -0.0370118224213905)
47+
)
48+
expect_equal(
49+
as.data.frame(result@summary[["tidy"]][2, c("conf.low", "conf.high")]),
50+
data.frame(conf.low = -0.501620572971525, conf.high = 0.587437018934182)
51+
)
52+
expect_equal(result@summary[["glance"]]$df.null, 1939052)
53+
54+
expect_matrix(result@model$vcov, nrows = 4, ncols = 4)
55+
expect_equal(
56+
diag(result@model$vcov),
57+
c(
58+
`(Intercept)` = 0.145504452458149, assigned_treatment = 0.0771872414783003,
59+
followup_time = 1.59771524178973e-06, nvarC = 4.40825112262011e-05
60+
)
61+
)
62+
})
63+
64+
test_that("fit_outcome_model works for stats_glm_logit with save_dir", {
65+
dir <- withr::local_tempdir(pattern = "glm_test")
66+
object <- stats_glm_logit(save_path = dir)
67+
data <- data.frame(
68+
y = rep(c(1, 0, 1, 0), times = c(15, 5, 5, 15)),
69+
x = rep(c(1, 0), times = c(20, 20))
70+
)
71+
result <- fit_outcome_model(
72+
object,
73+
data = data,
74+
formula = y ~ x
75+
)
76+
77+
expect_class(result, "te_outcome_fitted")
78+
expect_class(result, "te_stats_glm_logit_outcome_fitted")
79+
80+
expect_equal(result@summary[["tidy"]]$estimate, c(-1.0986123, 2.1972246))
81+
fitted_model <- readRDS(result@summary$save_path$save)
82+
expect_equal(
83+
fitted_model,
84+
result@model$model,
85+
ignore_function_env = TRUE,
86+
ignore_formula_env = TRUE
87+
)
88+
})

0 commit comments

Comments
 (0)