Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.20.3"
version = "0.21.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
Expand Down
52 changes: 32 additions & 20 deletions src/counterfactual_mean_based/clever_covariate.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@

"""
data_adaptive_ps_lower_bound(Ψ::StatisticalCMCompositeEstimand)
data_adaptive_ps_lower_bound(n::Int; max_lb=0.1)

This strategy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false)
but the study does not show strictly better behaviour of the strategy, so it is not a default for now.
Data-adaptive propensity score truncation level from Gruber et al. (2022):
"Data-Adaptive Selection of the Propensity Score Truncation Level for
Inverse-Probability–Weighted and Targeted Maximum Likelihood Estimators
of Marginal Point Treatment Effects" (doi:10.1093/aje/kwac087).

This is the default when `ps_lowerbound=nothing`. Here a maximum lower bound
is applied to prevent extreme truncation in small samples.
"""
data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) =
min(5 / (√(n)*log(n/5)), max_lb)
min(5 / (√(n)*log(n)), max_lb)

ps_lower_bound(n::Int, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(n; max_lb=max_lb)
ps_lower_bound(n::Int, lower_bound; max_lb=0.1) = min(max_lb, lower_bound)
Expand All @@ -18,12 +23,14 @@ function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat)
end
end

function balancing_weights(G, dataset; ps_lowerbound=1e-8)
jointlikelihood = ones(nrows(dataset))
function balancing_weights(G, dataset; ps_lowerbound=nothing)
n = nrows(dataset)
jointlikelihood = ones(n)
for Gᵢ ∈ G.components
jointlikelihood .*= likelihood(Gᵢ, dataset)
end
truncate!(jointlikelihood, ps_lowerbound)
actual_lowerbound = ps_lower_bound(n, ps_lowerbound)
truncate!(jointlikelihood, actual_lowerbound)
return 1. ./ jointlikelihood
end

Expand All @@ -32,39 +39,44 @@ end
Ψ::StatisticalCMCompositeEstimand,
Gs::Tuple{Vararg{ConditionalDistributionEstimate}},
dataset;
ps_lowerbound=1e-8,
ps_lowerbound=nothing,
weighted_fluctuation=false
)

Computes the clever covariate and weights that are used to fluctuate the initial Q.
Computes the clever covariate matrix, weights, and signs used to fluctuate the initial Q.

Returns a tuple `(H, w, signs)` where:
- `H` is an `n × K` matrix, one column per counterfactual in `indicator_fns(Ψ)`.
- `w` is a length-`n` weight vector.
- `signs` is a length-`K` vector of signs from `indicator_fns(Ψ)`.

if `weighted_fluctuation = false`:

- ``clever_covariate(t, w) = \\frac{SpecialIndicator(t)}{p(t|w)}``
- ``weight(t, w) = 1``
- ``H_{k}(t, w) = \\frac{I_k(t)}{p(t|w)}``
- ``w(t, w) = 1``

if `weighted_fluctuation = true`:

- ``clever_covariate(t, w) = SpecialIndicator(t)``
- ``weight(t, w) = \\frac{1}{p(t|w)}``
- ``H_{k}(t, w) = I_k(t)``
- ``w(t, w) = \\frac{1}{p(t|w)}``

where SpecialIndicator(t) is defined in `indicator_fns`.
where ``I_k(t)`` is the unsigned indicator for the k-th counterfactual.
"""
function clever_covariate_and_weights(
Ψ::StatisticalCMCompositeEstimand,
G,
dataset;
ps_lowerbound=1e-8,
ps_lowerbound=nothing,
weighted_fluctuation=false
)
# Compute the indicator values
# Compute the indicator matrix (n×K) and signs (K,)
T = selectcols(dataset, (p.estimand.outcome for p in G.components))
indic_vals = indicator_values(indicator_fns(Ψ), T)
indic_mat, signs = indicator_matrix(indicator_fns(Ψ), T)
weights = balancing_weights(G, dataset; ps_lowerbound=ps_lowerbound)
if weighted_fluctuation
return indic_vals, weights
return indic_mat, weights, signs
end
# Vanilla unweighted fluctuation
indic_vals .*= weights
return indic_vals, ones(size(weights, 1))
indic_mat .*= weights
return indic_mat, ones(size(weights, 1)), signs
end
60 changes: 41 additions & 19 deletions src/counterfactual_mean_based/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mutable struct Tmle <: Estimator
end

"""
Tmle(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing, machine_cache=false)
Tmle(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, weighted=false, tol=nothing, machine_cache=false)

Defines a TMLE estimator using the specified models for estimation of the nuisance parameters. The estimator is a
function that can be applied to estimate estimands for a dataset.
Expand All @@ -52,8 +52,9 @@ function that can be applied to estimate estimands for a dataset.
- collaborative_strategy (default: nothing): A collaborative strategy to use for the estimation. Then the resampling strategy is used to evaluate the candidates.
- resampling (default: `default_resampling(collaborative_strategy)`): Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla TMLE while
any valid `MLJ.ResamplingStrategy` will result in CV-TMLE.
- ps_lowerbound (default: 1e-8): Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will
result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/).
- ps_lowerbound (default: nothing): Lowerbound for the propensity score to avoid division by 0. The default `nothing`
uses data-adaptive truncation as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/): `5/(√n * log(n/5))`.
A fixed value can also be provided.
- weighted (default: true): Whether the fluctuation model is a classic GLM or a weighted version. The weighted fluctuation has
been shown to be more robust to positivity violation in practice.
- tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`.
Expand Down Expand Up @@ -81,7 +82,7 @@ function Tmle(;
models=default_models(),
collaborative_strategy=nothing,
resampling=default_resampling(collaborative_strategy),
ps_lowerbound=1e-8,
ps_lowerbound=nothing,
weighted=true,
tol=nothing,
max_iter=1,
Expand All @@ -103,6 +104,8 @@ end
function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1())
# Check if the inputs are suitable for the specified estimand
check_inputs(Ψ, dataset, tmle.prevalence)
# Reset collaborative strategy state before building relevant factors
tmle.collaborative_strategy !== nothing && initialise!(tmle.collaborative_strategy, Ψ)
# Make train-validation pairs
train_validation_indices = get_train_validation_indices(tmle.resampling, Ψ, dataset)
# Initial fit of the SCM's relevant factors
Expand Down Expand Up @@ -167,7 +170,7 @@ function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(),
return TMLEstimate(Ψ, Ψ̂, σ̂, n, IC), cache
end

gradient_and_estimate(::Tmle, Ψ, factors, dataset; ps_lowerbound=1e-8) =
gradient_and_estimate(::Tmle, Ψ, factors, dataset; ps_lowerbound=nothing) =
gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound)

#####################################################################
Expand All @@ -179,10 +182,11 @@ mutable struct Ose <: Estimator
resampling::Union{Nothing, ResamplingStrategy}
ps_lowerbound::Union{Float64, Nothing}
machine_cache::Bool
prevalence::Union{Nothing, Float64}
end

"""
Ose(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false)
Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false, prevalence=nothing)

Defines a One Step Estimator using the specified models for estimation of the nuisance parameters. The estimator is a
function that can be applied to estimate estimands for a dataset.
Expand All @@ -192,9 +196,11 @@ function that can be applied to estimate estimands for a dataset.
- models: A Dict(variable => model, ...) where the `variables` are the outcome variables modeled by the `models`.
- resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla estimation while
any valid `MLJ.ResamplingStrategy` will result in CV-OSE.
- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will
result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/).
- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` (default) will
result in a data adaptive definition as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/).
- machine_cache: Whether MLJ.machine created during estimation should cache data.
- prevalence: The prevalence of the outcome in the population. If provided, the estimator will use case-control
weighted estimation (CCW-OSE) to correct for biased sampling.

# Run Argument

Expand All @@ -213,22 +219,32 @@ ose = Ose()
Ψ̂ₙ, cache = ose(Ψ, dataset)
```
"""
Ose(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) =
Ose(models, resampling, ps_lowerbound, machine_cache)
# Keyword convenience constructor forwarding to the positional `Ose` constructor.
# `ps_lowerbound=nothing` defers to the data-adaptive propensity-score truncation;
# `prevalence` enables case-control weighted estimation when provided.
Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false, prevalence=nothing) =
Ose(models, resampling, ps_lowerbound, machine_cache, prevalence)

function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1())
# Check the estimand against the dataset
check_treatment_levels(Ψ, dataset)
check_inputs(Ψ, dataset, ose.prevalence)
# Make train-validation pairs
train_validation_indices = get_train_validation_indices(ose.resampling, Ψ, dataset)
# Initial fit of the SCM's relevant factors
initial_factors = get_relevant_factors(Ψ)
nomissing_dataset = nomissing(dataset, variables(initial_factors))
initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset;
# Get appropriate dataset (matched controls if prevalence is set)
fluctuation_dataset = get_fluctuation_dataset(dataset, initial_factors;
prevalence=ose.prevalence,
verbosity=verbosity
)
initial_factors_dataset = choose_initial_dataset(dataset, fluctuation_dataset;
train_validation_indices=train_validation_indices,
prevalence=nothing
prevalence=ose.prevalence
)
# Compute prevalence weights for fitting Q
prevalence_weights = compute_prevalence_weights(ose.prevalence, initial_factors_dataset[!, initial_factors.outcome_mean.outcome])
initial_factors_estimator = CMRelevantFactorsEstimator(;
models=ose.models,
train_validation_indices=train_validation_indices,
prevalence_weights=prevalence_weights
)
initial_factors_estimator = CMRelevantFactorsEstimator(;models=ose.models, train_validation_indices=train_validation_indices)
initial_factors_estimate = initial_factors_estimator(
initial_factors,
initial_factors_dataset;
Expand All @@ -237,19 +253,25 @@ function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), v
acceleration=acceleration
)
# Get propensity score truncation threshold
n = nrows(nomissing_dataset)
n = nrows(fluctuation_dataset)
ps_lowerbound = ps_lower_bound(n, ose.ps_lowerbound)

# Gradient and estimate
IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, fluctuation_dataset;
ps_lowerbound=ps_lowerbound, prevalence_weights=prevalence_weights)
σ̂ = std(IC)
n = size(IC, 1)
verbosity >= 1 && @info "Done."
return OSEstimate(Ψ, Ψ̂, σ̂, n, IC), cache
end

function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=1e-8)
IC, Ψ̂ = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound)
function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=nothing, prevalence_weights=nothing)
Q = factors.outcome_mean
G = factors.propensity_score
ctf_agg = counterfactual_aggregate(Ψ, Q, dataset)
gradient_Y_X = ∇YX(Ψ, Q, G, dataset; ps_lowerbound=ps_lowerbound)
y = float(dataset[!, Q.estimand.outcome])
IC, Ψ̂ = gradient_and_estimate(ctf_agg, gradient_Y_X, y, prevalence_weights)
IC_mean = mean(IC)
IC .-= IC_mean
return IC, Ψ̂ + IC_mean
Expand Down
39 changes: 28 additions & 11 deletions src/counterfactual_mean_based/fluctuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ mutable struct Fluctuation <: MLJBase.Supervised
initial_factors::MLCMRelevantFactors
tol::Union{Nothing, Float64}
max_iter::Int
ps_lowerbound::Float64
ps_lowerbound::Union{Float64, Nothing}
weighted::Bool
cache::Bool
prevalence_weights::Union{Nothing, Vector{Float64}}
end

Fluctuation(Ψ, initial_factors; tol=nothing, max_iter=1, ps_lowerbound=1e-8, weighted=false, cache=false, prevalence_weights=nothing) =
# Keyword convenience constructor forwarding to the positional `Fluctuation` constructor.
# `ps_lowerbound=nothing` defers to the data-adaptive propensity-score truncation level.
Fluctuation(Ψ, initial_factors; tol=nothing, max_iter=1, ps_lowerbound=nothing, weighted=false, cache=false, prevalence_weights=nothing) =
Fluctuation(Ψ, initial_factors, tol, max_iter, ps_lowerbound, weighted, cache, prevalence_weights)

one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:MLJBase.Continuous} = LinearRegressor(fit_intercept=false, offsetcol = :offset)
Expand All @@ -24,11 +24,23 @@ The GLM models require inputs of the same type, which sometimes is not the case
same_type_df(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} =
DataFrame(covariate=covariate, offset=convert(Vector{T1}, offset))

function fluctuation_input(covariate, ŷ)
offset = compute_offset()
"""
    fluctuation_input(covariate::AbstractVector, ŷ)

Assemble the single-covariate fluctuation design: the clever covariate together
with the offset computed from the initial predictions `ŷ`, with the offset
coerced to the covariate's element type (see `same_type_df`).
"""
fluctuation_input(covariate::AbstractVector, ŷ) =
    same_type_df(covariate, compute_offset(ŷ))

"""
    fluctuation_input(covariate::AbstractMatrix, ŷ)

Assemble the multi-covariate fluctuation design as a `DataFrame` holding an
`:offset` column (converted to the covariate element type when the types
differ) followed by one column `H_k` per covariate column.
"""
function fluctuation_input(covariate::AbstractMatrix, ŷ)
    raw_offset = compute_offset(ŷ)
    covtype = eltype(covariate)
    offset_col = eltype(raw_offset) == covtype ? raw_offset : convert(Vector{covtype}, raw_offset)
    design = DataFrame(:offset => offset_col)
    for (k, column) in enumerate(eachcol(covariate))
        # Materialise each column as an owned Vector, matching `covariate[:, k]` semantics.
        design[!, Symbol("H_", k)] = Vector(column)
    end
    return design
end

hasconverged(gradient, tol) = abs(mean(gradient)) < tol

"""
Expand All @@ -53,13 +65,13 @@ If prevalence weights are provided, they are applied to the weights and normaliz
function initialize_observed_cache(model, X, y)
Q⁰ = model.initial_factors.outcome_mean
G⁰ = model.initial_factors.propensity_score
H, w = clever_covariate_and_weights(
H, w, signs = clever_covariate_and_weights(
model.Ψ, G⁰, X;
ps_lowerbound=model.ps_lowerbound,
weighted_fluctuation=model.weighted
)
= MLJBase.predict(Q⁰, X)
return Dict{Symbol, Any}(:H => H, :w => w, : => , :y => float(y))
ŷ = MLJBase.predict(Q⁰, X)
return Dict{Symbol, Any}(:H => H, :w => w, :signs => signs, :ŷ => ŷ, :y => float(y))
end

"""
Expand All @@ -84,7 +96,7 @@ function initialize_counterfactual_cache(model, X)
T_ct = counterfactualTreatment(vals, Ttemplate)
X_ct = DataFrame((;(Symbol(colname) => colname ∈ names(T_ct) ? T_ct[!, colname] : X[!, colname] for colname in names(X))...))

covariates_ct, w_ct = clever_covariate_and_weights(Ψ,
covariates_ct, _, _ = clever_covariate_and_weights(Ψ,
G⁰,
X_ct;
ps_lowerbound=model.ps_lowerbound,
Expand Down Expand Up @@ -137,7 +149,8 @@ function compute_gradient_and_estimate_from_caches!(
# Compute gradient
Ey = expected_value(observed_cache[:ŷ])
ct_aggregate = compute_counterfactual_aggregate!(counterfactual_cache, Q)
gradient_Y_X = ∇YX(observed_cache[:H], observed_cache[:y], Ey, observed_cache[:w])
H_combined = observed_cache[:H] * observed_cache[:signs]
gradient_Y_X = ∇YX(H_combined, observed_cache[:y], Ey, observed_cache[:w])
gradient, Ψ̂ = gradient_and_estimate(ct_aggregate, gradient_Y_X, observed_cache[:y], prevalence_weights)
return gradient, Ψ̂
end
Expand Down Expand Up @@ -219,6 +232,7 @@ function gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights)
ctl_sum = q̄₀_over_J * sum(ct_aggregate_controls_batch)
gradient[case_id] = q₀ * (gradient_Y_X_cases[case_id] + ct_aggregate_case) + ctl_sum
end

point_estimate /= nC
gradient .-= point_estimate

Expand All @@ -234,7 +248,10 @@ end

# Without prevalence weights, the fluctuation uses the clever-covariate weights as-is.
get_fluctuation_weights(::Nothing, clever_covariate_weights) = clever_covariate_weights

get_fluctuation_weights(prevalence_weights, clever_covariate_weights) = clever_covariate_weights .* prevalence_weights
"""
    get_fluctuation_weights(prevalence_weights, clever_covariate_weights)

Combine the clever-covariate weights with prevalence weights normalised to
mean one, so the prevalence correction rescales observations without changing
the total weight mass.
"""
function get_fluctuation_weights(prevalence_weights, clever_covariate_weights)
    total = sum(prevalence_weights)
    n_obs = length(prevalence_weights)
    # Same operation order as before (divide, then scale) to preserve float results exactly.
    mean_one_weights = (prevalence_weights / total) * n_obs
    return clever_covariate_weights .* mean_one_weights
end

"""
MLJBase.fit(model::Fluctuation, verbosity, X, y)
Expand Down Expand Up @@ -307,7 +324,7 @@ end
Generates initial predictions and iteratively predicts from the fitted fluctuations.
"""
function MLJBase.predict(model::Fluctuation, machines, X)
covariate, _ = clever_covariate_and_weights(
covariate, _, _ = clever_covariate_and_weights(
model.Ψ, model.initial_factors.propensity_score, X;
ps_lowerbound=model.ps_lowerbound,
weighted_fluctuation=model.weighted
Expand Down
9 changes: 5 additions & 4 deletions src/counterfactual_mean_based/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,17 @@ Computes the projection of the gradient on the (Y | X) space.

This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached.
"""
function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8)
function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=nothing)
# Maybe can cache some results (H and E[Y|X]) to improve perf here
H, w = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound)
H, w, signs = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound)
y = float(dataset[!, Q.estimand.outcome])
Ey = expected_value(Q, dataset)
return ∇YX(H, y, Ey, w)
H_combined = H * signs
return ∇YX(H_combined, y, Ey, w)
end


function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8)
function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=nothing)
Q = factors.outcome_mean
G = factors.propensity_score
ctf_agg = counterfactual_aggregate(Ψ, Q, dataset)
Expand Down
4 changes: 2 additions & 2 deletions src/counterfactual_mean_based/nuisance_estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ end
function CMBasedFoldsTMLE(Ψ, initial_factors_estimate, train_validation_indices;
tol=nothing,
max_iter=1,
ps_lowerbound=1e-8,
ps_lowerbound=nothing,
weighted=false,
machine_cache=false,
)
Expand Down Expand Up @@ -461,7 +461,7 @@ function get_targeted_estimator(
initial_factors_estimate;
tol=nothing,
max_iter=1,
ps_lowerbound=1e-8,
ps_lowerbound=nothing,
weighted=true,
machine_cache=false,
models=nothing,
Expand Down
Loading
Loading