diff --git a/Project.toml b/Project.toml index d801ef3d..3fdb3e65 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TMLE" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] -version = "0.20.3" +version = "0.20.4" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 467f647f..590c0f34 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -11,7 +11,7 @@ mutable struct Tmle <: Estimator tol::Union{Float64, Nothing} max_iter::Int machine_cache::Bool - prevalence::Union{Nothing, Float64} + prevalence::Union{Nothing, Float64, Dict{Symbol, Float64}} function Tmle( models, resampling, @@ -59,7 +59,7 @@ been show to be more robust to positivity violation in practice. - tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`. - max_iter (default: 1): Maximum number of iterations for the TMLE algorithm. - machine_cache (default: false): Whether MLJ.machine created during estimation should cache data. -- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. +- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. This can either be a single value to be uniformly applied, or a Dict that maps each trait to a prevalence value. # Run Argument @@ -88,6 +88,7 @@ function Tmle(; machine_cache=false, prevalence=nothing ) + Tmle( models, resampling, @@ -100,24 +101,38 @@ function Tmle(; ) end +function prevalence_for_estimand(Ψ, prevalence) + prevalence === nothing && return nothing + + if prevalence isa Float64 + return prevalence + elseif prevalence isa Dict{Symbol, Float64} + return prevalence[Ψ.outcome] + else + @error("Unsupported prevalence type: $(typeof(prevalence))") + end +end + function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) + prevalence = prevalence_for_estimand(Ψ, tmle.prevalence) + # Check if the inputs are suitable for the specified estimand - check_inputs(Ψ, dataset, tmle.prevalence) + check_inputs(Ψ, dataset, prevalence) # Make train-validation pairs train_validation_indices = get_train_validation_indices(tmle.resampling, Ψ, dataset) # Initial fit of the SCM's relevant factors relevant_factors = get_relevant_factors(Ψ, collaborative_strategy=tmle.collaborative_strategy) fluctuation_dataset = get_fluctuation_dataset(dataset, relevant_factors; - prevalence=tmle.prevalence, + prevalence=prevalence, verbosity=verbosity ) initial_factors_dataset = choose_initial_dataset(dataset, fluctuation_dataset; train_validation_indices=train_validation_indices, - prevalence=tmle.prevalence + prevalence=prevalence ) - prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome]) + prevalence_weights = compute_prevalence_weights(prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome]) initial_factors_estimator = CMRelevantFactorsEstimator(tmle.collaborative_strategy; train_validation_indices=train_validation_indices, models=tmle.models, diff --git a/test/Project.toml b/test/Project.toml index 10c8a318..0461755b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,6 +21,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/counterfactual_mean_based/case_control_weighted_tmle.jl b/test/counterfactual_mean_based/case_control_weighted_tmle.jl index f64ea7d8..adc0d29c 100644 --- a/test/counterfactual_mean_based/case_control_weighted_tmle.jl +++ b/test/counterfactual_mean_based/case_control_weighted_tmle.jl @@ -9,6 +9,7 @@ using Distributions using MLJBase using MLJLinearModels using Statistics +using CSV # Helper: Draw a case-control sample with specified prevalence function subsample_case_control( @@ -20,22 +21,24 @@ function subsample_case_control( ) n_case = round(Int, prevalence * n) n_ctl = n - n_case - Ycol = pop[!, outcome_col] - cases = findall(Ycol .== 1) - controls = findall(Ycol .== 0) + + ycol = pop[!, outcome_col] + cases = findall(ycol .== 1) + controls = findall(ycol .== 0) if length(cases) < n_case - throw(ArgumentError("Not enough cases: have $(length(cases)), need $n_case")) + throw(ArgumentError("Not enough cases for $outcome_col: have $(length(cases)), need $n_case")) end if length(controls) < n_ctl - throw(ArgumentError("Not enough controls: have $(length(controls)), need $n_ctl")) + throw(ArgumentError("Not enough controls for $outcome_col: have $(length(controls)), need $n_ctl")) end ix_case = shuffle(rng, cases)[1:n_case] ix_ctl = shuffle(rng, controls)[1:n_ctl] - ix = vcat(ix_case, ix_ctl) - ix = shuffle(rng, ix) + ix = shuffle(rng, vcat(ix_case, ix_ctl)) + sub_pop = pop[ix, :] sub_pop.A = categorical(Bool.(sub_pop.A)) - sub_pop.Y = categorical(Bool.(sub_pop.Y)) + sub_pop[!, outcome_col] = categorical(Bool.(sub_pop[!, outcome_col])) + return sub_pop end @@ -44,6 +47,21 @@ function pY_given_A_W(A, W; α=-3, β=log(2), γ=log(1.5)) return 1 ./ (1 .+ exp.(-ηY)) end +function make_population(Npop::Int) + W = rand(Bernoulli(0.5), Npop) + ηA = -0.2 .+ 0.8 .* W + pA = 1 ./ (1 .+ exp.(-ηA)) + A = rand.(Bernoulli.(pA)) + + pY1 = pY_given_A_W(A, W; α=-3.0, β=log(2.0), γ=log(1.5)) + pY2 = pY_given_A_W(A, W; α=-2.2, β=log(1.4), γ=log(1.8)) + + Y1 = rand.(Bernoulli.(pY1)) + Y2 = rand.(Bernoulli.(pY2)) + + return DataFrame(W=W, A=A, Y1=Y1, Y2=Y2) +end + @testset "CCW-TMLE bootstrapping test" begin Random.seed!(42) Npop = 2_000_000 @@ -104,5 +122,94 @@ end @test mean(ccw_coverage) > 0.80 end +@testset "Test multi-trait CCW run with prevalence dictionary" begin + Random.seed!(42) + pop = make_population(200_000) + + # For running full model, copy pop + pop_copy = deepcopy(pop) + pop_copy.A = categorical(pop_copy.A) + pop_copy.Y1 = categorical(pop_copy.Y1) + pop_copy.Y2 = categorical(pop_copy.Y2) + + # True prevalences computed from the population + prevalence_by_trait = Dict( + :Y1 => mean(pop.Y1), + :Y2 => mean(pop.Y2), + ) + + # Ground truth for each trait, using the parameters that generated them + trait_params = Dict( + :Y1 => (α = -3.0, β = log(2.0), γ = log(1.5)), + :Y2 => (α = -2.2, β = log(1.4), γ = log(1.8)), + ) + + true_rd_by_trait = Dict{Symbol, Float64}() + for trait in [:Y1, :Y2] + p = trait_params[trait] + true_rd_by_trait[trait] = mean( + pY_given_A_W(1, pop.W; α=p.α, β=p.β, γ=p.γ) .- + pY_given_A_W(0, pop.W; α=p.α, β=p.β, γ=p.γ) + ) + end + + traits = [:Y1, :Y2] + n_sample = 10_000 + B = 10 + + for trait in traits + trait_prev = prevalence_by_trait[trait] + true_rd_trait = true_rd_by_trait[trait] + + Ψ = ATE( + outcome = trait, + treatment_values = (A = (case = true, control = false),), + treatment_confounders = (A = [:W],) + ) + + tmle_std = Tmle(weighted=false) + tmle_ccw = Tmle(prevalence=trait_prev, weighted=false) + tmle_ccw_prev_dict = Tmle(prevalence=prevalence_by_trait, weighted=false) + + # Check on full population: dict-based prevalence vs scalar prevalence + ccw_full_result, _ = tmle_ccw(Ψ, pop_copy; verbosity=0) + prev_dict_full_result, _ = tmle_ccw_prev_dict(Ψ, pop_copy; verbosity=0) + @test isapprox(ccw_full_result.estimate, prev_dict_full_result.estimate; atol=1e-3) + + std_estimates = Float64[] + ccw_estimates = Float64[] + prev_dict_estimates = Float64[] + + for b in 1:B + sample = subsample_case_control( + pop, + n_sample, + trait_prev; + outcome_col = trait, + rng = Random.MersenneTwister(1000 + b), + ) + + std_result, _ = tmle_std(Ψ, sample; verbosity=0) + ccw_result, _ = tmle_ccw(Ψ, sample; verbosity=0) + + # Dictionary-based prevalence run + prev_dict_result, _ = tmle_ccw_prev_dict(Ψ, sample; verbosity=0) + + push!(std_estimates, std_result.estimate) + push!(ccw_estimates, ccw_result.estimate) + push!(prev_dict_estimates, prev_dict_result.estimate) + + @test isfinite(std_result.estimate) + @test isfinite(ccw_result.estimate) + @test isfinite(prev_dict_result.estimate) + end + + # Dict-based prevalence and scalar prevalence should agree closely + @test isapprox(mean(prev_dict_estimates), mean(ccw_estimates); atol=1e-3) + # Bias should be reduced with correct prevalence specified + @test abs(mean(ccw_estimates) - true_rd_trait) < abs(mean(std_estimates) - true_rd_trait) + end +end + end true \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index b062eab3..55f5147e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -209,25 +209,36 @@ end # Check with prevalence prevalence = 0.1 + prevalence_dict = Dict(:Y => 0.1) + Ψ = CM( outcome = :Y, treatment_values = (T=1,), treatment_confounders = [:W] ) + ## The outcome must be binary dataset = DataFrame( - Y = categorical([1, 0, 1, 0, 1, 1, 2]), + Y = categorical([1, 0, 1, 0, 1, 1, 2]), # not binary T = categorical([1, 1, 0, 1, 0, 2, 2]), W = rand(7) ) + @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence) + + # Same check but using dict instead of scalar + @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence_dict) + ## The number of controls must be larger than the number of cases dataset = DataFrame( - Y = categorical([1, 0, 1, 0, 1, 1, 0]), + Y = categorical([1, 0, 1, 0, 1, 1, 0]), # more cases than controls T = categorical([1, 1, 0, 1, 0, 2, 2]), W = rand(7) ) + @test_throws ArgumentError("The dataset must contain more controls (0) than cases (1) when prevalence is provided.") TMLE.check_inputs(Ψ, dataset, prevalence) + # Same check with dict + @test_throws ArgumentError("The dataset must contain more controls (0) than cases (1) when prevalence is provided.") TMLE.check_inputs(Ψ, dataset, prevalence_dict) end @testset "Test get_fluctuation_dataset" begin