From ce0e8783bf569cfa9b1cd3128fd137d7b24fdb57 Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Wed, 8 Apr 2026 23:55:06 +0100 Subject: [PATCH 1/7] normalise weights; default data-adaptive propensity score truncation; update tests --- .../clever_covariate.jl | 24 +++++++---- src/counterfactual_mean_based/estimators.jl | 36 +++++++++------- src/counterfactual_mean_based/fluctuation.jl | 13 ++++-- src/counterfactual_mean_based/gradient.jl | 4 +- .../nuisance_estimators.jl | 4 +- src/estimators.jl | 10 +++-- src/utils.jl | 2 +- test/counterfactual_mean_based/fluctuation.jl | 41 ++++++++++--------- test/estimators_and_estimates.jl | 29 ++++++------- 9 files changed, 95 insertions(+), 68 deletions(-) diff --git a/src/counterfactual_mean_based/clever_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl index 13e9880a..479b089e 100644 --- a/src/counterfactual_mean_based/clever_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -1,9 +1,15 @@ """ - data_adaptive_ps_lower_bound(Ψ::StatisticalCMCompositeEstimand) + data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) -This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) -but the study does not show strictly better behaviour of the strategy so not a default for now. +Data-adaptive propensity score truncation level from Gruber et al. (2022): +"Data-Adaptive Selection of the Propensity Score Truncation Level for +Inverse-Probability–Weighted and Targeted Maximum Likelihood Estimators +of Marginal Point Treatment Effects" (doi:10.1093/aje/kwac087). + +This sets the propensity score lower bound to `5/(√n * log(n/5))`, capped at `max_lb`. +The paper formula is `5/(√n * ln(n))` but uses a slightly modified version here. +This is the default when `ps_lowerbound=nothing`. 
""" data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = min(5 / (√(n)*log(n/5)), max_lb) @@ -18,12 +24,14 @@ function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat) end end -function balancing_weights(G, dataset; ps_lowerbound=1e-8) - jointlikelihood = ones(nrows(dataset)) +function balancing_weights(G, dataset; ps_lowerbound=nothing) + n = nrows(dataset) + jointlikelihood = ones(n) for Gᵢ ∈ G.components jointlikelihood .*= likelihood(Gᵢ, dataset) end - truncate!(jointlikelihood, ps_lowerbound) + actual_lowerbound = ps_lower_bound(n, ps_lowerbound) + truncate!(jointlikelihood, actual_lowerbound) return 1. ./ jointlikelihood end @@ -32,7 +40,7 @@ end Ψ::StatisticalCMCompositeEstimand, Gs::Tuple{Vararg{ConditionalDistributionEstimate}}, dataset; - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted_fluctuation=false ) @@ -54,7 +62,7 @@ function clever_covariate_and_weights( Ψ::StatisticalCMCompositeEstimand, G, dataset; - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted_fluctuation=false ) # Compute the indicator values diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 467f647f..89ade308 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -12,6 +12,7 @@ mutable struct Tmle <: Estimator max_iter::Int machine_cache::Bool prevalence::Union{Nothing, Float64} + normalise_weights::Bool function Tmle( models, resampling, @@ -21,7 +22,8 @@ mutable struct Tmle <: Estimator tol, max_iter, machine_cache, - prevalence + prevalence, + normalise_weights ) if resampling === nothing && collaborative_strategy !== nothing @warn("Collaborative TMLE requires a resampling strategy but none was provided. 
Using the default resampling strategy.") @@ -35,13 +37,14 @@ mutable struct Tmle <: Estimator weighted, tol, max_iter, machine_cache, - prevalence + prevalence, + normalise_weights ) end end """ - Tmle(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing, machine_cache=false) + Tmle(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, weighted=false, tol=nothing, machine_cache=false) Defines a TMLE estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -52,8 +55,9 @@ function that can be applied to estimate estimands for a dataset. - collaborative_strategy (default: nothing): A collaborative strategy to use for the estimation. Then the resampling strategy is used to evaluate the candidates. - resampling (default: `default_resampling(collaborative_strategy)`): Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla TMLE while any valid `MLJ.ResamplingStrategy` will result in CV-TMLE. -- ps_lowerbound (default: 1e-8): Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will -result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). +- ps_lowerbound (default: nothing): Lowerbound for the propensity score to avoid division by 0. The default `nothing` +uses data-adaptive truncation as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/): `5/(√n * log(n/5))`. +A fixed value can also be provided. - weighted (default: false): Whether the fluctuation model is a classig GLM or a weighted version. The weighted fluctuation has been show to be more robust to positivity violation in practice. - tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`. 
@@ -81,12 +85,13 @@ function Tmle(; models=default_models(), collaborative_strategy=nothing, resampling=default_resampling(collaborative_strategy), - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted=true, tol=nothing, max_iter=1, machine_cache=false, - prevalence=nothing + prevalence=nothing, + normalise_weights=true ) Tmle( models, @@ -96,7 +101,8 @@ function Tmle(; weighted, tol, max_iter, machine_cache, - prevalence + prevalence, + normalise_weights ) end @@ -117,7 +123,7 @@ function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), prevalence=tmle.prevalence ) - prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome]) + prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome], normalisation = tmle.normalise_weights) initial_factors_estimator = CMRelevantFactorsEstimator(tmle.collaborative_strategy; train_validation_indices=train_validation_indices, models=tmle.models, @@ -167,7 +173,7 @@ function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), return TMLEstimate(Ψ, Ψ̂, σ̂, n, IC), cache end -gradient_and_estimate(::Tmle, Ψ, factors, dataset; ps_lowerbound=1e-8) = +gradient_and_estimate(::Tmle, Ψ, factors, dataset; ps_lowerbound=nothing) = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound) ##################################################################### @@ -182,7 +188,7 @@ mutable struct Ose <: Estimator end """ - Ose(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) + Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false) Defines a One Step Estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. 
@@ -192,8 +198,8 @@ function that can be applied to estimate estimands for a dataset. - models: A Dict(variable => model, ...) where the `variables` are the outcome variables modeled by the `models`. - resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla estimation while any valid `MLJ.ResamplingStrategy` will result in CV-OSE. -- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will -result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). +- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` (default) will +result in a data adaptive definition as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/). - machine_cache: Whether MLJ.machine created during estimation should cache data. # Run Argument @@ -213,7 +219,7 @@ ose = Ose() Ψ̂ₙ, cache = ose(Ψ, dataset) ``` """ -Ose(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) = +Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false) = Ose(models, resampling, ps_lowerbound, machine_cache) function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) @@ -248,7 +254,7 @@ function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), v return OSEstimate(Ψ, Ψ̂, σ̂, n, IC), cache end -function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=1e-8) +function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=nothing) IC, Ψ̂ = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound) IC_mean = mean(IC) IC .-= IC_mean diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index dd7c2b51..1017e06d 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ 
b/src/counterfactual_mean_based/fluctuation.jl @@ -3,13 +3,13 @@ mutable struct Fluctuation <: MLJBase.Supervised initial_factors::MLCMRelevantFactors tol::Union{Nothing, Float64} max_iter::Int - ps_lowerbound::Float64 + ps_lowerbound::Union{Float64, Nothing} weighted::Bool cache::Bool prevalence_weights::Union{Nothing, Vector{Float64}} end -Fluctuation(Ψ, initial_factors; tol=nothing, max_iter=1, ps_lowerbound=1e-8, weighted=false, cache=false, prevalence_weights=nothing) = +Fluctuation(Ψ, initial_factors; tol=nothing, max_iter=1, ps_lowerbound=nothing, weighted=false, cache=false, prevalence_weights=nothing) = Fluctuation(Ψ, initial_factors, tol, max_iter, ps_lowerbound, weighted, cache, prevalence_weights) one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:MLJBase.Continuous} = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -219,7 +219,14 @@ function gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) ctl_sum = q̄₀_over_J * sum(ct_aggregate_controls_batch) gradient[case_id] = q₀ * (gradient_Y_X_cases[case_id] + ct_aggregate_case) + ctl_sum end - point_estimate /= nC + # Normalize by sum of weights - this gives the correct point estimate + # for both normalized and unnormalized weights + weight_sum = sum(weights) + point_estimate /= weight_sum + # Rescale gradient to match: when weights are normalized, gradient components + # are scaled by n/nC, so we need to scale them back by nC/n = nC/weight_sum + scale_factor = nC / weight_sum + gradient .*= scale_factor gradient .-= point_estimate return gradient, point_estimate diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index 11f6466a..fcea6802 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -52,7 +52,7 @@ Computes the projection of the gradient on the (Y | X) space. This part of the gradient is evaluated on the original dataset. 
All quantities have been precomputed and cached. """ -function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) +function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=nothing) # Maybe can cache some results (H and E[Y|X]) to improve perf here H, w = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) y = float(dataset[!, Q.estimand.outcome]) @@ -61,7 +61,7 @@ function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound= end -function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) +function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=nothing) Q = factors.outcome_mean G = factors.propensity_score ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) diff --git a/src/counterfactual_mean_based/nuisance_estimators.jl b/src/counterfactual_mean_based/nuisance_estimators.jl index 319c4ba6..825a75be 100644 --- a/src/counterfactual_mean_based/nuisance_estimators.jl +++ b/src/counterfactual_mean_based/nuisance_estimators.jl @@ -320,7 +320,7 @@ end function CMBasedFoldsTMLE(Ψ, initial_factors_estimate, train_validation_indices; tol=nothing, max_iter=1, - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted=false, machine_cache=false, ) @@ -461,7 +461,7 @@ function get_targeted_estimator( initial_factors_estimate; tol=nothing, max_iter=1, - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted=true, machine_cache=false, models=nothing, diff --git a/src/estimators.jl b/src/estimators.jl index 6c489346..e10aedfd 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -64,16 +64,20 @@ Calculates weights for a case-control study to use in the fitting of nuisance fu - `prevalence`: The prevalence of the outcome in the population. 
- `y`: The outcome variable across observations, which should be binary vector.` """ -function compute_prevalence_weights(prevalence::Float64, y::AbstractVector) +function compute_prevalence_weights(prevalence::Float64, y::AbstractVector; normalisation=true) J = sum(y .== 0) ÷ sum(y .== 1) weights = Vector{Float64}(undef, length(y)) for i in eachindex(y) weights[i] = y[i] == 1 ? prevalence : (1 - prevalence) / J end - return weights + if normalisation + return (weights/sum(weights))*length(weights) + else + return weights + end end -compute_prevalence_weights(::Nothing, y) = nothing +compute_prevalence_weights(::Nothing, y; normalisation=true) = nothing get_training_prevalence_weights(::Nothing, train_indices) = nothing diff --git a/src/utils.jl b/src/utils.jl index c54d4280..6f6af55c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,7 +35,7 @@ This is to avoid the expensive complications of: - Tracking sample_ids """ function choose_initial_dataset(dataset, fluctuation_dataset; train_validation_indices=nothing, prevalence=nothing) - # In CV mode or prevalence mode, we get back to the no fluctuation_dataset + # In CV mode or case-control weighted mode, we get back to the no fluctuation_dataset if !isnothing(train_validation_indices) || !isnothing(prevalence) return fluctuation_dataset else diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index 47c33147..c9e56ffa 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -190,7 +190,7 @@ end X = dataset[!, collect(η̂ₙ.outcome_mean.estimand.parents)] y = dataset[!, η̂ₙ.outcome_mean.estimand.outcome] - fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; weighted=false, prevalence_weights=prevalence_weights, max_iter=5) + fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; weighted=false, prevalence_weights=prevalence_weights, max_iter=5, ps_lowerbound=nothing) machs, cache, report = MLJBase.fit(fluctuation, 0, X, y) gradient = 
report.gradients[1] @test mean(gradient) ≈ 0.0 atol=1e-4 @@ -205,17 +205,18 @@ end q_0 = 0.8 q_0_bar_over_J = 0.2 gradient, point_estimate = TMLE.gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) - # The 7ths control is not used in these computations - @test point_estimate == 0.5*( - (q_0 * 2) + q_0_bar_over_J * (1 + 3) # First case (idx=2) grouped with controls (idx=[1, 3]) - + - (q_0 * 4) + q_0_bar_over_J * (5 + 6) # Second case (idx=4) grouped with controls (idx=[5, 6]) - ) - # gradient_Y_X and ct_aggregate are summed together and the point estimate is removed - @test gradient == [ - (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1) + (3 + 3)), # First case (idx=2) grouped with controls (idx=[1, 3]) - (q_0 * (4 + 4)) + q_0_bar_over_J * ((5 + 5) + (6 + 6)) # Second case (idx=4) grouped with controls (idx=[5, 6]) - ] .- point_estimate + # The 7th control is not used in these computations + # With new formula: point_estimate = raw_sum / sum(weights), gradient scaled by nC/sum(weights) + raw_point_sum = (q_0 * 2) + q_0_bar_over_J * (1 + 3) + (q_0 * 4) + q_0_bar_over_J * (5 + 6) + @test point_estimate ≈ raw_point_sum / sum(weights) atol=1e-10 + # gradient components scaled by nC / sum(weights) = 2/2.6, then point_estimate subtracted + nC = 2 + scale_factor = nC / sum(weights) + raw_gradient = [ + (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1) + (3 + 3)), + (q_0 * (4 + 4)) + q_0_bar_over_J * ((5 + 5) + (6 + 6)) + ] + @test gradient ≈ raw_gradient .* scale_factor .- point_estimate atol=1e-10 # When there is exactly J controls per case, all controls are used ct_aggregate = [1, 2, 3, 4] @@ -225,17 +226,17 @@ end q_0 = 0.8 q_0_bar_over_J = 0.2 gradient, point_estimate = TMLE.gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) - # The 7ths control is not used in these computations + # sum(weights) = 2.0 = nC, so scale_factor = 1.0 and old/new formulas agree @test point_estimate ≈ 0.5*( - (q_0 * 2) + (q_0_bar_over_J * 1) # First case (idx=2) grouped with 
controls (idx=[1, 3]) + (q_0 * 2) + (q_0_bar_over_J * 1) + - (q_0 * 4) + (q_0_bar_over_J * 3) # Second case (idx=4) grouped with controls (idx=[5, 6]) + (q_0 * 4) + (q_0_bar_over_J * 3) ) atol=1e-10 - # gradient_Y_X and ct_aggregate are summed together and the point estimate is removed - @test gradient == [ - (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1)), # First case (idx=2) grouped with controls (idx=[1, 3]) - (q_0 * (4 + 4)) + q_0_bar_over_J * ((3 + 3)) # Second case (idx=4) grouped with controls (idx=[5, 6]) - ] .- point_estimate + # gradient_Y_X and ct_aggregate are summed together and the point estimate is removed + @test gradient ≈ [ + (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1)), + (q_0 * (4 + 4)) + q_0_bar_over_J * ((3 + 3)) + ] .- point_estimate atol=1e-10 end end diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 9fbc8fe9..2c054c2a 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -26,7 +26,7 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) cache = Dict() # Model that supports weights estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(),prevalence_weights=weights) - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) # Model that does NOT support weights (e.g., LogisticClassifier) estimator = TMLE.MLConditionalDistributionEstimator(LogisticClassifier(), prevalence_weights=weights) @@ -34,7 +34,7 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) # Pipeline that supports weights estimator = TMLE.MLConditionalDistributionEstimator(with_encoder(LinearBinaryClassifier()), prevalence_weights=weights) - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, 
binary_dataset; cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) # Pipeline that does NOT support weights estimator = TMLE.MLConditionalDistributionEstimator(with_encoder(LogisticClassifier()), prevalence_weights=weights) @@ -93,33 +93,34 @@ end end @testset "Test MLConditionalDistributionEstimator: binary outcome with prevalence weights" begin - # Simulate a binary outcome with imbalanced classes - n = 100 - X, y = make_moons(n) - y = copy(y) + # Binary outcome with imbalanced classes + y = copy(binary_dataset[!, :Y]) y[1:80] .= 0 # Make one class more prevalent y[81:100] .= 1 - binary_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) + case_control_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) # Set prevalence to 0.5 (true prevalence in population) prevalence = 0.5 - weights = TMLE.compute_prevalence_weights(prevalence, binary_dataset.Y) - @test Set(weights) == Set([0.5, 0.125]) # Check weights are correct + weights = TMLE.compute_prevalence_weights(prevalence, case_control_dataset.Y, normalisation = false) + @test Set(weights) == Set([0.5, 0.125]) # Check weights are correct (non-normalised) + weights = TMLE.compute_prevalence_weights(prevalence, case_control_dataset.Y) + @test sum(weights) == n #Check normalised weights sum to n + estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(), nothing, weights) # Fit with prevalence weights cache = Dict() - conditional_density_estimate = estimator(estimand, binary_dataset; cache=cache, verbosity=1) + conditional_density_estimate = estimator(estimand, case_control_dataset; cache=cache, verbosity=1) @test conditional_density_estimate isa TMLE.MLConditionalDistribution # Check that predictions are probabilities - ŷ = MLJBase.predict(conditional_density_estimate, binary_dataset) + ŷ = 
MLJBase.predict(conditional_density_estimate, case_control_dataset) @test all(0.0 .<= [ŷ[i].prob_given_ref[2] for i in eachindex(ŷ)] .<= 1.0) # Check that weights are used (by comparing with unweighted fit) estimator_unweighted = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) - conditional_density_estimate_unweighted = estimator_unweighted(estimand, binary_dataset; cache=Dict(), verbosity=0) - ŷ_unweighted = MLJBase.predict(conditional_density_estimate_unweighted, binary_dataset) + conditional_density_estimate_unweighted = estimator_unweighted(estimand, case_control_dataset; cache=Dict(), verbosity=0) + ŷ_unweighted = MLJBase.predict(conditional_density_estimate_unweighted, case_control_dataset) μ̂_weighted = [ŷ[i].prob_given_ref[2] for i in eachindex(ŷ)] μ̂_unweighted = [ŷ_unweighted[i].prob_given_ref[2] for i in eachindex(ŷ_unweighted)] @test !all(isapprox.(μ̂_weighted, μ̂_unweighted; atol=1e-6)) # Should differ @@ -219,7 +220,7 @@ end prevalence = 0.5 binary_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) weights = TMLE.compute_prevalence_weights(prevalence, binary_dataset.Y) - @test Set(weights) == Set([0.5, 0.125]) # Check weights are correct + @test Set(weights) == Set([2.5, 0.625]) # Check normalized weights are correct nfolds = 3 train_validation_indices = Tuple(MLJBase.train_test_pairs(StratifiedCV(nfolds=nfolds), 1:n, binary_dataset, binary_dataset.Y)) estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) From f46e745e52b1cdc8a8cbbccc5c9839ee9cc415b9 Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Thu, 16 Apr 2026 10:50:28 +0100 Subject: [PATCH 2/7] change truncation and normalisation strategy --- .../clever_covariate.jl | 7 ++- src/counterfactual_mean_based/estimators.jl | 17 +++----- src/counterfactual_mean_based/fluctuation.jl | 15 +++---- src/estimators.jl | 15 +++---- .../clever_covariate.jl | 43 ++++++++++++++++++- test/counterfactual_mean_based/fluctuation.jl | 16 +++---- test/estimators_and_estimates.jl | 6 +-- 7 files 
changed, 75 insertions(+), 44 deletions(-) diff --git a/src/counterfactual_mean_based/clever_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl index 479b089e..76dd350c 100644 --- a/src/counterfactual_mean_based/clever_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -7,12 +7,11 @@ Data-adaptive propensity score truncation level from Gruber et al. (2022): Inverse-Probability–Weighted and Targeted Maximum Likelihood Estimators of Marginal Point Treatment Effects" (doi:10.1093/aje/kwac087). -This sets the propensity score lower bound to `5/(√n * log(n/5))`, capped at `max_lb`. -The paper formula is `5/(√n * ln(n))` but uses a slightly modified version here. -This is the default when `ps_lowerbound=nothing`. +This is the default when `ps_lowerbound=nothing`. Here a maximum lower bound +is applied to prevent extreme truncation in small samples. """ data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = - min(5 / (√(n)*log(n/5)), max_lb) + min(5 / (√(n)*log(n)), max_lb) ps_lower_bound(n::Int, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(n; max_lb=max_lb) ps_lower_bound(n::Int, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 89ade308..c5744c65 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -12,7 +12,6 @@ mutable struct Tmle <: Estimator max_iter::Int machine_cache::Bool prevalence::Union{Nothing, Float64} - normalise_weights::Bool function Tmle( models, resampling, @@ -22,8 +21,7 @@ mutable struct Tmle <: Estimator tol, max_iter, machine_cache, - prevalence, - normalise_weights + prevalence ) if resampling === nothing && collaborative_strategy !== nothing @warn("Collaborative TMLE requires a resampling strategy but none was provided. 
Using the default resampling strategy.") @@ -37,8 +35,7 @@ mutable struct Tmle <: Estimator weighted, tol, max_iter, machine_cache, - prevalence, - normalise_weights + prevalence ) end end @@ -90,8 +87,7 @@ function Tmle(; tol=nothing, max_iter=1, machine_cache=false, - prevalence=nothing, - normalise_weights=true + prevalence=nothing ) Tmle( models, @@ -101,14 +97,15 @@ function Tmle(; weighted, tol, max_iter, machine_cache, - prevalence, - normalise_weights + prevalence ) end function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) # Check if the inputs are suitable for the specified estimand check_inputs(Ψ, dataset, tmle.prevalence) + # Reset collaborative strategy state before building relevant factors + tmle.collaborative_strategy !== nothing && initialise!(tmle.collaborative_strategy, Ψ) # Make train-validation pairs train_validation_indices = get_train_validation_indices(tmle.resampling, Ψ, dataset) # Initial fit of the SCM's relevant factors @@ -123,7 +120,7 @@ function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), prevalence=tmle.prevalence ) - prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome], normalisation = tmle.normalise_weights) + prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome]) initial_factors_estimator = CMRelevantFactorsEstimator(tmle.collaborative_strategy; train_validation_indices=train_validation_indices, models=tmle.models, diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 1017e06d..4953eebc 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -219,14 +219,8 @@ function gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) ctl_sum = q̄₀_over_J * sum(ct_aggregate_controls_batch) 
gradient[case_id] = q₀ * (gradient_Y_X_cases[case_id] + ct_aggregate_case) + ctl_sum end - # Normalize by sum of weights - this gives the correct point estimate - # for both normalized and unnormalized weights - weight_sum = sum(weights) - point_estimate /= weight_sum - # Rescale gradient to match: when weights are normalized, gradient components - # are scaled by n/nC, so we need to scale them back by nC/n = nC/weight_sum - scale_factor = nC / weight_sum - gradient .*= scale_factor + + point_estimate /= nC gradient .-= point_estimate return gradient, point_estimate @@ -241,7 +235,10 @@ end get_fluctuation_weights(prevalence_weights::Nothing, clever_covariate_weights) = clever_covariate_weights -get_fluctuation_weights(prevalence_weights, clever_covariate_weights) = clever_covariate_weights .* prevalence_weights +function get_fluctuation_weights(prevalence_weights, clever_covariate_weights) + normalised_prevalence_weights = (prevalence_weights / sum(prevalence_weights)) * length(prevalence_weights) + return clever_covariate_weights .* normalised_prevalence_weights +end """ MLJBase.fit(model::Fluctuation, verbosity, X, y) diff --git a/src/estimators.jl b/src/estimators.jl index e10aedfd..f33247d4 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -48,7 +48,9 @@ function fit_mlj_model(model, X, y; parents=names(X), cache=false, weights=nothi mach = machine(model, X, y; cache=cache) else if supervised_learner_supports_weights(model) - mach = machine(model, X, y, weights; cache=cache) + # Normalise weights at point-of-use + normalised_weights = (weights / sum(weights)) * length(weights) + mach = machine(model, X, y, normalised_weights; cache=cache) else throw(ArgumentError("The model $(model) does not support weights and cannot be used with prevalence.")) end @@ -61,23 +63,20 @@ end compute_prevalence_weights(prevalence, y) Calculates weights for a case-control study to use in the fitting of nuisance functions. +Returns raw (unnormalised) weights. 
Normalisation happens at point-of-use in fit_mlj_model and fluctuation. - `prevalence`: The prevalence of the outcome in the population. - `y`: The outcome variable across observations, which should be binary vector.` """ -function compute_prevalence_weights(prevalence::Float64, y::AbstractVector; normalisation=true) +function compute_prevalence_weights(prevalence::Float64, y::AbstractVector) J = sum(y .== 0) ÷ sum(y .== 1) weights = Vector{Float64}(undef, length(y)) for i in eachindex(y) weights[i] = y[i] == 1 ? prevalence : (1 - prevalence) / J end - if normalisation - return (weights/sum(weights))*length(weights) - else - return weights - end + return weights end -compute_prevalence_weights(::Nothing, y; normalisation=true) = nothing +compute_prevalence_weights(::Nothing, y) = nothing get_training_prevalence_weights(::Nothing, train_indices) = nothing diff --git a/test/counterfactual_mean_based/clever_covariate.jl b/test/counterfactual_mean_based/clever_covariate.jl index ebb399f3..f445b1e1 100644 --- a/test/counterfactual_mean_based/clever_covariate.jl +++ b/test/counterfactual_mean_based/clever_covariate.jl @@ -12,7 +12,7 @@ using DataFrames # Nothing results in data adaptive lower bound no lower than max_lb @test TMLE.ps_lower_bound(n, nothing) == max_lb @test TMLE.data_adaptive_ps_lower_bound(n) == max_lb - @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 + @test TMLE.data_adaptive_ps_lower_bound(1000) ≈ 5 / (√1000 * log(1000)) # Otherwise use the provided threhsold provided it's lower than max_lb @test TMLE.ps_lower_bound(n, 1e-8) == 1.0e-8 @test TMLE.ps_lower_bound(n, 1) == 0.1 @@ -137,6 +137,47 @@ end @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] end +@testset "Test clever_covariate_and_weights: Data-Adapative 1 treatment" begin + Ψ = ATE( + outcome=:Y, + treatment_values=(T=(case="a", control="b"),), + treatment_confounders=(T=[:W],), + ) + n = 150 + dataset = DataFrame( + T = categorical(vcat(fill("a", 149), ["b"])), + Y = 
collect(1.0:n), + W = rand(n), + ) + + propensity_score_estimator = TMLE.JointConditionalDistributionEstimator(Dict(:T => TMLE.MLConditionalDistributionEstimator(ConstantClassifier()))) + propensity_score_estimate = propensity_score_estimator( + (TMLE.ConditionalDistribution(:T, [:W]),), + dataset, + verbosity=0 + ) + + # Se + weighted_fluctuation = true + ps_lowerbound = nothing + cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + + @test w[1:149] ≈ fill(1/(149/150), 149) + @test w[150] ≈ 1/(5/(sqrt(n)*log(n))) + + weighted_fluctuation = false + cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov[1:149] ≈ fill(1/(149/150), 149) + @test cov[150] ≈ -1/(5/(sqrt(n)*log(n))) + @test w == ones(150) +end + end true \ No newline at end of file diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index c9e56ffa..e71961d6 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -205,18 +205,18 @@ end q_0 = 0.8 q_0_bar_over_J = 0.2 gradient, point_estimate = TMLE.gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) - # The 7th control is not used in these computations - # With new formula: point_estimate = raw_sum / sum(weights), gradient scaled by nC/sum(weights) - raw_point_sum = (q_0 * 2) + q_0_bar_over_J * (1 + 3) + (q_0 * 4) + q_0_bar_over_J * (5 + 6) - @test point_estimate ≈ raw_point_sum / sum(weights) atol=1e-10 - # gradient components scaled by nC / sum(weights) = 2/2.6, then point_estimate subtracted + # The 7th control is not used in these computations (J = 5÷2 = 2 controls per case) + # Cases at indices 2, 4; Controls at indices 1, 3, 5, 6, 7 (7th dropped) + # point_estimate = sum over cases of (q₀*ct_case + 
q̄₀/J*sum(ct_controls)) / nC nC = 2 - scale_factor = nC / sum(weights) + raw_point_sum = (q_0 * 2) + q_0_bar_over_J * (1 + 3) + (q_0 * 4) + q_0_bar_over_J * (5 + 6) + @test point_estimate ≈ raw_point_sum / nC atol=1e-10 + # gradient = q₀*(ct_case + grad_case) + q̄₀/J*sum(ct_ctl + grad_ctl) - point_estimate raw_gradient = [ (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1) + (3 + 3)), (q_0 * (4 + 4)) + q_0_bar_over_J * ((5 + 5) + (6 + 6)) ] - @test gradient ≈ raw_gradient .* scale_factor .- point_estimate atol=1e-10 + @test gradient ≈ raw_gradient .- point_estimate atol=1e-10 # When there is exactly J controls per case, all controls are used ct_aggregate = [1, 2, 3, 4] @@ -226,7 +226,7 @@ end q_0 = 0.8 q_0_bar_over_J = 0.2 gradient, point_estimate = TMLE.gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) - # sum(weights) = 2.0 = nC, so scale_factor = 1.0 and old/new formulas agree + # Cases at indices 2, 4; Controls at indices 1, 3 (J = 2÷2 = 1 control per case) @test point_estimate ≈ 0.5*( (q_0 * 2) + (q_0_bar_over_J * 1) + diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 2c054c2a..aeedef52 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -100,10 +100,8 @@ end case_control_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) # Set prevalence to 0.5 (true prevalence in population) prevalence = 0.5 - weights = TMLE.compute_prevalence_weights(prevalence, case_control_dataset.Y, normalisation = false) - @test Set(weights) == Set([0.5, 0.125]) # Check weights are correct (non-normalised) weights = TMLE.compute_prevalence_weights(prevalence, case_control_dataset.Y) - @test sum(weights) == n #Check normalised weights sum to n + @test Set(weights) == Set([0.5, 0.125]) # Weights should not be normalised estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(), nothing, weights) @@ -220,7 +218,7 @@ end prevalence = 0.5 
binary_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) weights = TMLE.compute_prevalence_weights(prevalence, binary_dataset.Y) - @test Set(weights) == Set([2.5, 0.625]) # Check normalized weights are correct + @test Set(weights) == Set([0.5, 0.125]) nfolds = 3 train_validation_indices = Tuple(MLJBase.train_test_pairs(StratifiedCV(nfolds=nfolds), 1:n, binary_dataset, binary_dataset.Y)) estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) From 258327e46c01b153d28920e6ec089089e0bb8fad Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Thu, 16 Apr 2026 16:55:10 +0100 Subject: [PATCH 3/7] add multi-dimensional fluctuation parameter; update tests as the change in fluctuation leads to slightly different estimates; change greedy and adaptive tests as different confounders are now selected --- .../clever_covariate.jl | 27 ++++++++++------- src/counterfactual_mean_based/fluctuation.jl | 29 ++++++++++++++----- src/counterfactual_mean_based/gradient.jl | 5 ++-- src/utils.jl | 25 ++++++++++++++++ .../clever_covariate.jl | 28 ++++++++++-------- .../covariate_based_strategies.jl | 26 ++++++++--------- .../double_robustness_aie.jl | 11 +++---- test/counterfactual_mean_based/fluctuation.jl | 19 ++++++------ .../non_regression_test.jl | 6 ++-- 9 files changed, 113 insertions(+), 63 deletions(-) diff --git a/src/counterfactual_mean_based/clever_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl index 76dd350c..04dff49f 100644 --- a/src/counterfactual_mean_based/clever_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -43,19 +43,24 @@ end weighted_fluctuation=false ) -Computes the clever covariate and weights that are used to fluctuate the initial Q. +Computes the clever covariate matrix, weights, and signs used to fluctuate the initial Q. + +Returns a tuple `(H, w, signs)` where: +- `H` is an `n × K` matrix, one column per counterfactual in `indicator_fns(Ψ)`. +- `w` is a length-`n` weight vector. 
+- `signs` is a length-`K` vector of signs from `indicator_fns(Ψ)`. if `weighted_fluctuation = false`: -- ``clever_covariate(t, w) = \\frac{SpecialIndicator(t)}{p(t|w)}`` -- ``weight(t, w) = 1`` +- ``H_{k}(t, w) = \\frac{I_k(t)}{p(t|w)}`` +- ``w(t, w) = 1`` if `weighted_fluctuation = true`: -- ``clever_covariate(t, w) = SpecialIndicator(t)`` -- ``weight(t, w) = \\frac{1}{p(t|w)}`` +- ``H_{k}(t, w) = I_k(t)`` +- ``w(t, w) = \\frac{1}{p(t|w)}`` -where SpecialIndicator(t) is defined in `indicator_fns`. +where ``I_k(t)`` is the unsigned indicator for the k-th counterfactual. """ function clever_covariate_and_weights( Ψ::StatisticalCMCompositeEstimand, @@ -64,14 +69,14 @@ function clever_covariate_and_weights( ps_lowerbound=nothing, weighted_fluctuation=false ) - # Compute the indicator values + # Compute the indicator matrix (n×K) and signs (K,) T = selectcols(dataset, (p.estimand.outcome for p in G.components)) - indic_vals = indicator_values(indicator_fns(Ψ), T) + indic_mat, signs = indicator_matrix(indicator_fns(Ψ), T) weights = balancing_weights(G, dataset; ps_lowerbound=ps_lowerbound) if weighted_fluctuation - return indic_vals, weights + return indic_mat, weights, signs end # Vanilla unweighted fluctuation - indic_vals .*= weights - return indic_vals, ones(size(weights, 1)) + indic_mat .*= weights + return indic_mat, ones(size(weights, 1)), signs end \ No newline at end of file diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 4953eebc..2d489b63 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -24,11 +24,23 @@ The GLM models require inputs of the same type, which sometimes is not the case same_type_df(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = DataFrame(covariate=covariate, offset=convert(Vector{T1}, offset)) -function fluctuation_input(covariate, ŷ) - offset = compute_offset(ŷ) +function 
fluctuation_input(covariate::AbstractVector, ŷ) + offset = compute_offset(ŷ) return same_type_df(covariate, offset) end +function fluctuation_input(covariate::AbstractMatrix, ŷ) + offset = compute_offset(ŷ) + T = eltype(covariate) + typed_offset = eltype(offset) == T ? offset : convert(Vector{T}, offset) + K = size(covariate, 2) + df = DataFrame(:offset => typed_offset) + for k in 1:K + df[!, Symbol("H_", k)] = covariate[:, k] + end + return df +end + hasconverged(gradient, tol) = abs(mean(gradient)) < tol """ @@ -53,13 +65,13 @@ If prevalence weights are provided, they are applied to the weights and normaliz function initialize_observed_cache(model, X, y) Q⁰ = model.initial_factors.outcome_mean G⁰ = model.initial_factors.propensity_score - H, w = clever_covariate_and_weights( + H, w, signs = clever_covariate_and_weights( model.Ψ, G⁰, X; ps_lowerbound=model.ps_lowerbound, weighted_fluctuation=model.weighted ) - ŷ = MLJBase.predict(Q⁰, X) - return Dict{Symbol, Any}(:H => H, :w => w, :ŷ => ŷ, :y => float(y)) + ŷ = MLJBase.predict(Q⁰, X) + return Dict{Symbol, Any}(:H => H, :w => w, :signs => signs, :ŷ => ŷ, :y => float(y)) end """ @@ -84,7 +96,7 @@ function initialize_counterfactual_cache(model, X) T_ct = counterfactualTreatment(vals, Ttemplate) X_ct = DataFrame((;(Symbol(colname) => colname ∈ names(T_ct) ? 
T_ct[!, colname] : X[!, colname] for colname in names(X))...)) - covariates_ct, w_ct = clever_covariate_and_weights(Ψ, + covariates_ct, _, _ = clever_covariate_and_weights(Ψ, G⁰, X_ct; ps_lowerbound=model.ps_lowerbound, @@ -137,7 +149,8 @@ function compute_gradient_and_estimate_from_caches!( # Compute gradient Ey = expected_value(observed_cache[:ŷ]) ct_aggregate = compute_counterfactual_aggregate!(counterfactual_cache, Q) - gradient_Y_X = ∇YX(observed_cache[:H], observed_cache[:y], Ey, observed_cache[:w]) + H_combined = observed_cache[:H] * observed_cache[:signs] + gradient_Y_X = ∇YX(H_combined, observed_cache[:y], Ey, observed_cache[:w]) gradient, Ψ̂ = gradient_and_estimate(ct_aggregate, gradient_Y_X, observed_cache[:y], prevalence_weights) return gradient, Ψ̂ end @@ -311,7 +324,7 @@ end Generates initial predictions and iteratively predicts from the fitted fluctuations. """ function MLJBase.predict(model::Fluctuation, machines, X) - covariate, _ = clever_covariate_and_weights( + covariate, _, _ = clever_covariate_and_weights( model.Ψ, model.initial_factors.propensity_score, X; ps_lowerbound=model.ps_lowerbound, weighted_fluctuation=model.weighted diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index fcea6802..57bf487f 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -54,10 +54,11 @@ This part of the gradient is evaluated on the original dataset. 
All quantities h """ function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=nothing) # Maybe can cache some results (H and E[Y|X]) to improve perf here - H, w = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) + H, w, signs = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) y = float(dataset[!, Q.estimand.outcome]) Ey = expected_value(Q, dataset) - return ∇YX(H, y, Ey, w) + H_combined = H * signs + return ∇YX(H_combined, y, Ey, w) end diff --git a/src/utils.jl b/src/utils.jl index 6f6af55c..c9bbab0a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -89,6 +89,31 @@ function indicator_values(indicators, T) return indic end +""" + indicator_matrix(indicators, T) + +Returns a tuple `(mat, signs)` where: +- `mat` is an `n × K` matrix of unsigned indicator columns (1.0 where treatment + values match, 0.0 otherwise), one column per entry in `indicators`. +- `signs` is a length-K vector of the corresponding signs from `indicators`. 
+""" +function indicator_matrix(indicators, T) + n = nrows(T) + indicator_keys = sort(collect(keys(indicators)), by=string) + signs = [indicators[k] for k in indicator_keys] + K = length(indicator_keys) + mat = zeros(Float64, n, K) + for (i, row) in enumerate(Tables.namedtupleiterator(T)) + vals = values(row) + for (j, key) in enumerate(indicator_keys) + if vals == key + mat[i, j] = 1.0 + end + end + end + return mat, signs +end + expected_value(ŷ::AbstractArray{<:UnivariateFinite{<:Union{OrderedFactor{2}, Multiclass{2}}}}) = pdf.(ŷ, levels(first(ŷ))[2]) expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) expected_value(ŷ::AbstractVector{<:Real}) = ŷ diff --git a/test/counterfactual_mean_based/clever_covariate.jl b/test/counterfactual_mean_based/clever_covariate.jl index f445b1e1..df7ca277 100644 --- a/test/counterfactual_mean_based/clever_covariate.jl +++ b/test/counterfactual_mean_based/clever_covariate.jl @@ -38,20 +38,21 @@ end ) weighted_fluctuation = true ps_lowerbound = 1e-8 - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] + @test size(H) == (7, 2) + @test H * signs == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] weighted_fluctuation = false - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + @test H * signs == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] @test w == ones(7) end @@ -84,11 +85,12 @@ end ps_lowerbound = 1e-8 weighted_fluctuation = false - cov, w = 
TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 + @test size(H) == (7, 4) + @test H * signs ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] end @@ -129,11 +131,12 @@ end verbosity=0 ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 + @test size(H) == (7, 8) + @test H * signs ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] end @@ -160,7 +163,7 @@ end # Se weighted_fluctuation = true ps_lowerbound = nothing - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) @@ -169,12 +172,13 @@ end @test w[150] ≈ 1/(5/(sqrt(n)*log(n))) weighted_fluctuation = false - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov[1:149] ≈ fill(1/(149/150), 149) - @test cov[150] ≈ -1/(5/(sqrt(n)*log(n))) + combined = H * signs + @test combined[1:149] ≈ fill(1/(149/150), 149) + @test combined[150] ≈ -1/(5/(sqrt(n)*log(n))) @test w == ones(150) end diff --git a/test/counterfactual_mean_based/covariate_based_strategies.jl 
b/test/counterfactual_mean_based/covariate_based_strategies.jl index 5885adac..7fbc06f6 100644 --- a/test/counterfactual_mean_based/covariate_based_strategies.jl +++ b/test/counterfactual_mean_based/covariate_based_strategies.jl @@ -65,7 +65,7 @@ include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) @test adaptive_strategy.remaining_confounders == Set{Symbol}() @test adaptive_strategy.current_confounders == Set([:W₁, :W₂, :W₃]) - # Full run: this leads to only W₃ being used for the propensity score + # Full run: this leads to W₁ and W₃ being used for the propensity score ctmle = Tmle(collaborative_strategy=adaptive_strategy) Ψ = AIE( outcome = :Y, @@ -77,8 +77,8 @@ include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) ) result_ctmle, cache = ctmle(Ψ, dataset;verbosity=0); targeted_η̂ = cache[:targeted_factors] - @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₃)) - @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₃,)) + @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁, :W₃)) + @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₁, :W₃)) end @testset "Test GreedyStrategy Interface" begin @@ -158,19 +158,19 @@ end cache=cache, machine_cache=machine_cache, ) - @test new_g == (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁)), TMLE.ConditionalDistribution(:T₂, (:W₁,))) - @test new_ĝ == ĝ + @test new_g == (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₂)), TMLE.ConditionalDistribution(:T₂, ())) + @test new_ĝ == ĝ # Update the collaborative strategy - TMLE.update!(collaborative_strategy, new_g, new_ĝ) - @test collaborative_strategy.remaining_confounders == Set{Symbol}([:W₂, :W₃]) - @test collaborative_strategy.current_confounders == Set{Symbol}([:W₁]) + TMLE.update!(collaborative_strategy, new_g, new_ĝ) + @test 
collaborative_strategy.remaining_confounders == Set{Symbol}([:W₁, :W₃]) + @test collaborative_strategy.current_confounders == Set{Symbol}([:W₂]) # Let's iterate again step_k_candidate_iterator = TMLE.StepKPropensityScoreIterator(collaborative_strategy, Ψ, dataset, models, new_targeted_η̂ₙ) g_ĝ_candidates = collect(step_k_candidate_iterator) g_candidates = Set(first.(g_ĝ_candidates)) @test g_candidates == Set([ (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁, :W₂)), TMLE.ConditionalDistribution(:T₂, (:W₁,))), - (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁)), TMLE.ConditionalDistribution(:T₂, (:W₁, :W₃))) + (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₂)), TMLE.ConditionalDistribution(:T₂, (:W₃,))) ]) # Find optimal candidate again new_g, new_ĝ, new_targeted_η̂ₙ, new_loss, use_fluct = TMLE.step_k_best_candidate( @@ -187,7 +187,7 @@ end ) @test new_g == (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁, :W₂)), TMLE.ConditionalDistribution(:T₂, (:W₁,))) - # Full run: this leads to only W₃ being used for the propensity score + # Full run: this leads to only W₁ being used for the propensity score ctmle = Tmle(collaborative_strategy=collaborative_strategy) Ψ = AIE( outcome = :Y, @@ -199,8 +199,8 @@ end ) result_ctmle, cache = ctmle(Ψ, dataset;verbosity=0); targeted_η̂ = cache[:targeted_factors] - @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₃)) - @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₃,)) + @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁)) + @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₁,)) end @testset "Integration Test using the AdaptiveCorrelationStrategy" begin @@ -367,7 +367,7 @@ end @test fitted_propensity_score === new_g ## The loss should be smaller because we fluctuate through the previous model, however in finite samples ## I suppose this is not 
warranted, we check they are approximately equal and the loss with Q̄n,k,* < Q̄n,k - @test loss ≈ new_loss atol=1e-5 + @test loss ≈ new_loss atol=1e-4 # We can pretend the step_k_best_candidate had used the non targeted outcome model new_targeted_η̂ₙ_bis, new_loss_bis = TMLE.get_new_targeted_candidate(targeted_η̂ₙ, new_ĝₙ, fluctuation_model, dataset; use_fluct=false, diff --git a/test/counterfactual_mean_based/double_robustness_aie.jl b/test/counterfactual_mean_based/double_robustness_aie.jl index ab515d9e..9fd62544 100644 --- a/test/counterfactual_mean_based/double_robustness_aie.jl +++ b/test/counterfactual_mean_based/double_robustness_aie.jl @@ -13,7 +13,8 @@ include(joinpath(TEST_DIR, "helper_fns.jl")) include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) cont_interacter = InteractionTransformer(order=2) |> LinearRegressor -cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1.) +# remove regularization, last example was misspecified for Q and G +cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=0) @testset "Test Double Robustness AIE on binary_outcome_binary_treatment_pb" begin @@ -38,7 +39,7 @@ cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1. ) dr_estimators = double_robust_estimators(models, resampling=StratifiedCV()) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-6) test_mean_inf_curve_almost_zero(results.ose; atol=1e-9) # The initial estimate is far away naive = Plugin(models[:Y]) @@ -53,7 +54,7 @@ cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1. 
) dr_estimators = double_robust_estimators(models, resampling=StratifiedCV()) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-7) test_mean_inf_curve_almost_zero(results.ose; atol=1e-9) # The initial estimate is far away naive = Plugin(models[:Y]) @@ -144,10 +145,10 @@ end test_mean_inf_curve_almost_zero(results.tmle; atol=1e-5) test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) - # The initial estimate is far away + # The initial plugin estimate is close to truth (Q is well specified) naive = Plugin(models[:Y]) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) - @test naive_result ≈ -0.02 atol=1e-2 + @test naive_result ≈ Ψ₀ atol=5e-2 end diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index e71961d6..e6910c9e 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -43,20 +43,20 @@ using MLJGLMInterface expected_value = [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0] # Constant predictions of the mean as per Q⁰ @test mean.(counterfactual_cache.predictions[1]) == mean.(counterfactual_cache.predictions[2]) == expected_value @test counterfactual_cache.signs == [1., -1.] 
- @test counterfactual_cache.covariates == [ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0] - ] + # Counterfactual covariates are now n×K matrices + @test size(counterfactual_cache.covariates[1]) == (7, 2) + @test size(counterfactual_cache.covariates[2]) == (7, 2) observed_cache = TMLE.initialize_observed_cache(weighted_fluctuation, X, y) - @test observed_cache[:ŷ] isa Vector{<:Normal} - @test observed_cache[:H] == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] # This is used to fit, so weight has been removed + @test observed_cache[:ŷ] isa Vector{<:Normal} + @test size(observed_cache[:H]) == (7, 2) + @test observed_cache[:H] * observed_cache[:signs] == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test observed_cache[:w] == [1.75, 3.5, 7., 1.75, 1.75, 3.5, 1.75] # weight is separate @test observed_cache[:y] isa Vector{Float64} ## Second fit the fluctuation w_machines, cache, w_report = MLJBase.fit(weighted_fluctuation, 0, X, y) ### Only one machine, only fitted the clever covariate mach = only(w_machines) - @test fitted_params(mach).features == [:covariate] + @test fitted_params(mach).features == [:H_1, :H_2] ### Report entries @test length(w_report.epsilons) == length(w_report.estimates) == length(w_report.gradients) == 1 @test w_report.epsilons[1][1] !== 0 @@ -70,7 +70,7 @@ using MLJGLMInterface unweighted_fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; weighted=false, tol=0, max_iter=3) ## First check the weight and covariates from the observed cache observed_cache = TMLE.initialize_observed_cache(unweighted_fluctuation, X, y) - @test observed_cache[:H] == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + @test observed_cache[:H] * observed_cache[:signs] == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] @test observed_cache[:w] == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ## Second fit the fluctuation logs = [(:info, "TMLE step: 1."), (:info, "TMLE step: 2."), (:info, "TMLE step: 3."), (:info, "Convergence criterion not reached.")] @@ -138,7 +138,8 @@ end @test 
length(counterfactual_cache.covariates) == 4 observed_cache = TMLE.initialize_observed_cache(fluctuation, X, y) @test observed_cache[:ŷ] isa UnivariateFiniteVector - @test isapprox(observed_cache[:H], [2.44, -3.26, -3.26, 2.44, 2.44, -6.12, 8.16], atol=0.1) + @test size(observed_cache[:H]) == (7, 4) + @test isapprox(observed_cache[:H] * observed_cache[:signs], [2.44, -3.26, -3.26, 2.44, 2.44, -6.12, 8.16], atol=0.1) @test observed_cache[:w] == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] @test observed_cache[:y] isa Vector{Float64} diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 4813ac10..f182b500 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -11,10 +11,10 @@ using JSON using YAML function regression_tests(tmle_result) - @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 + @test estimate(tmle_result) ≈ -0.184910 atol = 1e-6 l, u = confint(significance_test(tmle_result)) - @test l ≈ -0.279246 atol = 1e-6 - @test u ≈ -0.091821 atol = 1e-6 + @test l ≈ -0.278604 atol = 1e-6 + @test u ≈ -0.091215 atol = 1e-6 @test OneSampleZTest(tmle_result) isa OneSampleZTest end From d70cabf3265f3f1b28b85053f5577d2eb46cbf1a Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Thu, 16 Apr 2026 17:05:22 +0100 Subject: [PATCH 4/7] update logging to account for deprecation --- test/estimators_and_estimates.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index aeedef52..420d45c9 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -46,7 +46,7 @@ end estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) # Fitting with no cache cache = Dict() - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + 
conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) expected_features = collect(estimand.parents) @test conditional_density_estimate isa TMLE.MLConditionalDistribution @test fitted_params(conditional_density_estimate.machine).features == expected_features @@ -64,7 +64,7 @@ end @test_logs (:info, reuse_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) ## Changing the model leads to refit new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(fit_intercept=false)) - @test_logs (:info, fit_log) new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) end @testset "Test MLConditionalDistributionEstimator: continuous outcome" begin From 6b872634b3df5100b784c6d6e79e54c53e62d97e Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Thu, 16 Apr 2026 17:46:42 +0100 Subject: [PATCH 5/7] CCW-OSE; trying to supress log errors --- src/counterfactual_mean_based/estimators.jl | 43 +++++++++---- .../case_control_weighted_tmle.jl | 60 +++++++++++++++++++ test/estimators_and_estimates.jl | 12 ++-- 3 files changed, 97 insertions(+), 18 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index c5744c65..50b80796 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -182,10 +182,11 @@ mutable struct Ose <: Estimator resampling::Union{Nothing, ResamplingStrategy} ps_lowerbound::Union{Float64, Nothing} machine_cache::Bool + prevalence::Union{Nothing, Float64} end """ - Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false) + Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false, prevalence=nothing) Defines a One Step 
Estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -198,6 +199,8 @@ any valid `MLJ.ResamplingStrategy` will result in CV-OSE. - ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` (default) will result in a data adaptive definition as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/). - machine_cache: Whether MLJ.machine created during estimation should cache data. +- prevalence: The prevalence of the outcome in the population. If provided, the estimator will use case-control +weighted estimation (CCW-OSE) to correct for biased sampling. # Run Argument @@ -216,22 +219,32 @@ ose = Ose() Ψ̂ₙ, cache = ose(Ψ, dataset) ``` """ -Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false) = - Ose(models, resampling, ps_lowerbound, machine_cache) +Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false, prevalence=nothing) = + Ose(models, resampling, ps_lowerbound, machine_cache, prevalence) function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) # Check the estimand against the dataset - check_treatment_levels(Ψ, dataset) + check_inputs(Ψ, dataset, ose.prevalence) # Make train-validation pairs train_validation_indices = get_train_validation_indices(ose.resampling, Ψ, dataset) # Initial fit of the SCM's relevant factors initial_factors = get_relevant_factors(Ψ) - nomissing_dataset = nomissing(dataset, variables(initial_factors)) - initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset; + # Get appropriate dataset (matched controls if prevalence is set) + fluctuation_dataset = get_fluctuation_dataset(dataset, initial_factors; + prevalence=ose.prevalence, + verbosity=verbosity + ) + initial_factors_dataset = 
choose_initial_dataset(dataset, fluctuation_dataset; train_validation_indices=train_validation_indices, - prevalence=nothing + prevalence=ose.prevalence + ) + # Compute prevalence weights for fitting Q + prevalence_weights = compute_prevalence_weights(ose.prevalence, initial_factors_dataset[!, initial_factors.outcome_mean.outcome]) + initial_factors_estimator = CMRelevantFactorsEstimator(; + models=ose.models, + train_validation_indices=train_validation_indices, + prevalence_weights=prevalence_weights ) - initial_factors_estimator = CMRelevantFactorsEstimator(;models=ose.models, train_validation_indices=train_validation_indices) initial_factors_estimate = initial_factors_estimator( initial_factors, initial_factors_dataset; @@ -240,19 +253,25 @@ function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), v acceleration=acceleration ) # Get propensity score truncation threshold - n = nrows(nomissing_dataset) + n = nrows(fluctuation_dataset) ps_lowerbound = ps_lower_bound(n, ose.ps_lowerbound) # Gradient and estimate - IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, fluctuation_dataset; + ps_lowerbound=ps_lowerbound, prevalence_weights=prevalence_weights) σ̂ = std(IC) n = size(IC, 1) verbosity >= 1 && @info "Done." 
return OSEstimate(Ψ, Ψ̂, σ̂, n, IC), cache end -function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=nothing) - IC, Ψ̂ = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound) +function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=nothing, prevalence_weights=nothing) + Q = factors.outcome_mean + G = factors.propensity_score + ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) + gradient_Y_X = ∇YX(Ψ, Q, G, dataset; ps_lowerbound=ps_lowerbound) + y = float(dataset[!, Q.estimand.outcome]) + IC, Ψ̂ = gradient_and_estimate(ctf_agg, gradient_Y_X, y, prevalence_weights) IC_mean = mean(IC) IC .-= IC_mean return IC, Ψ̂ + IC_mean diff --git a/test/counterfactual_mean_based/case_control_weighted_tmle.jl b/test/counterfactual_mean_based/case_control_weighted_tmle.jl index f64ea7d8..ea90db0e 100644 --- a/test/counterfactual_mean_based/case_control_weighted_tmle.jl +++ b/test/counterfactual_mean_based/case_control_weighted_tmle.jl @@ -104,5 +104,65 @@ end @test mean(ccw_coverage) > 0.80 end +@testset "CCW-OSE bootstrapping test" begin + Random.seed!(42) + Npop = 2_000_000 + # Simulate population + W = rand(Bernoulli(0.5), Npop) + ηA = -0.2 .+ 0.8 .* W + pA = 1 ./ (1 .+ exp.(-ηA)) + A = rand.(Bernoulli.(pA)) + α, β, γ = -3, log(2), log(1.5) + pY = pY_given_A_W(A, W; α=α, β=β, γ=γ) + Y = rand.(Bernoulli.(pY)) + pop = DataFrame(W=W, A=A, Y=Y) + q₀ = mean(pop.Y .== 1) + + # Obtain the true risk difference (ATE) + true_rd = mean(pY_given_A_W(1, pop.W) .- pY_given_A_W(0, pop.W)) + + # Define ATE estimand + Ψ = ATE( + outcome=:Y, + treatment_values=(A=(case=true, control=false),), + treatment_confounders=(A=[:W],) + ) + # Standard OSE (no prevalence correction) + ose_std = Ose() + # CCW-OSE (with true prevalence) + ose_ccw = Ose(prevalence=q₀) + + # Draw a series of biased samples of size n_sample + n_sample = 10_000 + cc_prev = 0.2 + B = 30 + ccw_ose_results = Vector{Any}(undef, B) + std_ose_results = Vector{Any}(undef, B) 
+ ccw_coverage = Vector{Bool}(undef, B) + std_coverage = Vector{Bool}(undef, B) + + for i in 1:B + sample = subsample_case_control(pop, n_sample, cc_prev, rng=Random.MersenneTwister(i)) + std_result, _ = ose_std(Ψ, sample; verbosity=0) + ccw_result, _ = ose_ccw(Ψ, sample; verbosity=0) + std_ose_results[i] = std_result.estimate + ccw_ose_results[i] = ccw_result.estimate + # Compare bias: CCW-OSE should be much less biased than standard OSE + std_bias = abs(std_result.estimate - true_rd) + ccw_bias = abs(ccw_result.estimate - true_rd) + @test ccw_bias < std_bias / 2 + # Retrieve coverage + lb, ub = confint(significance_test(ccw_result)) + ccw_coverage[i] = lb < true_rd < ub + lb, ub = confint(significance_test(std_result)) + std_coverage[i] = lb < true_rd < ub + end + # See if, on average, CCW-OSE outperforms standard OSE + @test (mean(ccw_ose_results) - true_rd) < (mean(std_ose_results) - true_rd) + # Test coverage is improved as well + @test mean(ccw_coverage) > mean(std_coverage) + @test mean(ccw_coverage) > 0.80 +end + end true \ No newline at end of file diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 420d45c9..238b181b 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -61,7 +61,7 @@ end # Check cache management ## Uses the cache instead of fitting new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) - @test_logs (:info, reuse_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, reuse_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) ## Changing the model leads to refit new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(fit_intercept=false)) @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) @@ -72,7 +72,7 @@ end ## Probabilistic Model model = MLJGLMInterface.LinearRegressor() 
estimator = TMLE.MLConditionalDistributionEstimator(model) - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, continuous_dataset; cache=Dict(), verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, continuous_dataset; cache=Dict(), verbosity=verbosity) ŷ = MLJBase.predict(conditional_density_estimate, continuous_dataset) @test ŷ isa Vector{Normal{Float64}} μ̂ = TMLE.expected_value(conditional_density_estimate, continuous_dataset) @@ -134,7 +134,7 @@ end train_validation_indices ) cache = Dict() - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) @test conditional_density_estimate isa TMLE.SampleSplitMLConditionalDistribution expected_features = collect(estimand.parents) @test all(fitted_params(mach).features == expected_features for mach in conditional_density_estimate.machines) @@ -158,21 +158,21 @@ end LinearBinaryClassifier(), train_validation_indices ) - @test_logs (:info, reuse_log) estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) + @test_logs (:info, reuse_log) match_mode=:any estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) ## Changing the model leads to refit new_model = LinearBinaryClassifier(fit_intercept=false) new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( new_model, train_validation_indices ) - @test_logs (:info, fit_log) new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) ## Changing the train/validation splits leads to refit train_validation_indices = Tuple(MLJBase.train_test_pairs(CV(nfolds=4), 1:n, binary_dataset)) new_estimator = 
TMLE.SampleSplitMLConditionalDistributionEstimator( new_model, train_validation_indices ) - @test_logs (:info, fit_log) new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) end @testset "Test SampleSplitMLConditionalDistributionEstimator: Continuous outcome" begin From 0130bebe978cf1840072fb02b865b7504eb01dd4 Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Thu, 16 Apr 2026 18:02:27 +0100 Subject: [PATCH 6/7] more log suppression for deprecation warnings --- .../counterfactual_mean_based/estimators_and_estimates.jl | 8 ++++---- test/counterfactual_mean_based/fluctuation.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index 78cd4eb1..32d2e0db 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -48,7 +48,7 @@ end (:info, TMLE.fit_string(Q)) ) cache = Dict() - η̂ₙ = @test_logs fit_log... η̂(η, dataset; cache=cache, verbosity=1) + η̂ₙ = @test_logs fit_log...
match_mode=:any η̂(η, dataset; cache=cache, verbosity=1) # Test both sub estimands have been fitted @test η̂ₙ.outcome_mean isa TMLE.MLConditionalDistribution @test fitted_params(η̂ₙ.outcome_mean.machine) isa NamedTuple @@ -59,7 +59,7 @@ end # Both models unchanged, η̂ₙ is fully reused new_η̂ = TMLE.CMRelevantFactorsEstimator(models=models) full_reuse_log = (:info, TMLE.reuse_string(η)) - @test_logs full_reuse_log new_η̂(η, dataset; cache=cache, verbosity=1) + @test_logs full_reuse_log match_mode=:any new_η̂(η, dataset; cache=cache, verbosity=1) # Changing one model, only the other one is refitted models[:T₁] = LogisticClassifier(fit_intercept=false) new_η̂ = TMLE.CMRelevantFactorsEstimator(models=models) @@ -68,7 +68,7 @@ end (:info, TMLE.fit_string(G[1])), (:info, TMLE.reuse_string(Q)) ) - @test_logs partial_reuse_log... new_η̂(η, dataset; cache=cache, verbosity=1) + @test_logs partial_reuse_log... match_mode=:any new_η̂(η, dataset; cache=cache, verbosity=1) # Adding a resampling strategy cv_fit_log = ( @@ -78,7 +78,7 @@ end ) train_validation_indices = MLJBase.train_test_pairs(CV(nfolds=3), 1:nrow(dataset), dataset) resampled_η̂ = TMLE.CMRelevantFactorsEstimator(models=models, train_validation_indices=train_validation_indices) - η̂ₙ = @test_logs cv_fit_log... resampled_η̂(η, dataset; cache=cache, verbosity=1) + η̂ₙ = @test_logs cv_fit_log... 
match_mode=:any resampled_η̂(η, dataset; cache=cache, verbosity=1) @test length(η̂ₙ.outcome_mean.machines) == 3 ps_component = only(η̂ₙ.propensity_score.components) @test length(ps_component.machines) == 3 diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index e6910c9e..f5648ce5 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -74,7 +74,7 @@ using MLJGLMInterface @test observed_cache[:w] == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ## Second fit the fluctuation logs = [(:info, "TMLE step: 1."), (:info, "TMLE step: 2."), (:info, "TMLE step: 3."), (:info, "Convergence criterion not reached.")] - uw_machines, cache, uw_report = @test_logs logs... MLJBase.fit(unweighted_fluctuation, 1, X, y); + uw_machines, cache, uw_report = @test_logs logs... match_mode=:any MLJBase.fit(unweighted_fluctuation, 1, X, y); ### Only one machine, only fitted the clever covariate @test length(uw_machines) == 3 ### Report entries From d87bcc1ab7b6151cb7ef051b28f618df9de6271c Mon Sep 17 00:00:00 2001 From: joshua-slaughter Date: Sun, 19 Apr 2026 21:47:30 +0100 Subject: [PATCH 7/7] small bug fixes in tests; up version --- Project.toml | 2 +- test/counterfactual_mean_based/double_robustness_aie.jl | 4 ++-- test/utils.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index d801ef3d..1e5355ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TMLE" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] -version = "0.20.3" +version = "0.21.0" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/test/counterfactual_mean_based/double_robustness_aie.jl b/test/counterfactual_mean_based/double_robustness_aie.jl index 9fd62544..38703847 100644 --- a/test/counterfactual_mean_based/double_robustness_aie.jl +++ b/test/counterfactual_mean_based/double_robustness_aie.jl @@ -13,8 
+13,8 @@ include(joinpath(TEST_DIR, "helper_fns.jl")) include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) cont_interacter = InteractionTransformer(order=2) |> LinearRegressor -# remove regularization, last example was misspecified for Q and G -cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=0) +# add some regularization +cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1e-6) @testset "Test Double Robustness AIE on binary_outcome_binary_treatment_pb" begin diff --git a/test/utils.jl b/test/utils.jl index b062eab3..96b00f84 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -200,7 +200,7 @@ end treatment_confounders=[:W] ) # The treatment levels correctly appear in the dataset - dataset = DataFrame(Y=rand(10), T=rand(0:1, 10), W=rand(10)) + dataset = DataFrame(Y=rand(10), T=repeat([0, 1], 5), W=rand(10)) @test TMLE.check_inputs(Ψ, dataset, nothing) isa Any # The treatment levels do not appear in the dataset dataset = DataFrame(Y=rand(10), T=rand(2:3, 10), W=rand(10))