From d655ea9a95cdbab1e22b256d956350bb92428ca9 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Tue, 7 Apr 2026 15:07:33 +0100 Subject: [PATCH 01/19] add function for mapping prevalence to a trait for a given estimator, and accomodating prevalence-file --- src/counterfactual_mean_based/estimators.jl | 49 ++++++++++++++++----- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 467f647f..ae66e350 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -11,7 +11,8 @@ mutable struct Tmle <: Estimator tol::Union{Float64, Nothing} max_iter::Int machine_cache::Bool - prevalence::Union{Nothing, Float64} + prevalence::Union{Nothing, Float64, Dict{Symbol, Float64}} + prevalence_file::Union{Nothing, String} function Tmle( models, resampling, @@ -21,7 +22,8 @@ mutable struct Tmle <: Estimator tol, max_iter, machine_cache, - prevalence + prevalence, + prevalence_file ) if resampling === nothing && collaborative_strategy !== nothing @warn("Collaborative TMLE requires a resampling strategy but none was provided. Using the default resampling strategy.") @@ -35,7 +37,8 @@ mutable struct Tmle <: Estimator weighted, tol, max_iter, machine_cache, - prevalence + prevalence, + prevalence_file ) end end @@ -59,7 +62,8 @@ been show to be more robust to positivity violation in practice. - tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`. - max_iter (default: 1): Maximum number of iterations for the TMLE algorithm. - machine_cache (default: false): Whether MLJ.machine created during estimation should cache data. -- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. +- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. +- prevalence_file (default: nothing): A file including the prevalence for each trait can also be provided. This must be a TSV with the first column including the traits, and the second column including the prevalences. # Run Argument @@ -86,8 +90,12 @@ function Tmle(; tol=nothing, max_iter=1, machine_cache=false, - prevalence=nothing + prevalence=nothing, + prevalence_file=nothing ) + + prevalence_parsed = prevalence_file !== nothing ? load_prevalence_map(prevalence_file) : prevalence + Tmle( models, resampling, @@ -96,28 +104,49 @@ function Tmle(; weighted, tol, max_iter, machine_cache, - prevalence + prevalence_parsed, + prevalence_file ) end +function load_prevalence_map(prevalence_file::AbstractString) + df = CSV.read(prevalence_file, DataFrame; delim='\t') + isempty(df) && return nothing + return Dict(Symbol(df.trait[i]) => Float64(df.prevalence[i]) for i in 1:nrow(df)) +end + +function prevalence_for_estimand(Ψ, prevalence) + prevalence === nothing && return nothing + + if prevalence isa Float64 + return prevalence + elseif prevalence isa Dict{Symbol, Float64} + return prevalence[get_outcome(Ψ)] + else + @error("Unsupported prevalence type: $(typeof(prevalence))") + end +end + function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) + prevalence = prevalence_for_estimand(Ψ, tmle.prevalence) + # Check if the inputs are suitable for the specified estimand - check_inputs(Ψ, dataset, tmle.prevalence) + check_inputs(Ψ, dataset, prevalence) # Make train-validation pairs train_validation_indices = get_train_validation_indices(tmle.resampling, Ψ, dataset) # Initial fit of the SCM's relevant factors relevant_factors = get_relevant_factors(Ψ, collaborative_strategy=tmle.collaborative_strategy) fluctuation_dataset = get_fluctuation_dataset(dataset, relevant_factors; - prevalence=tmle.prevalence, + prevalence=prevalence, verbosity=verbosity ) initial_factors_dataset = choose_initial_dataset(dataset, fluctuation_dataset; train_validation_indices=train_validation_indices, - prevalence=tmle.prevalence + prevalence=prevalence ) - prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome]) + prevalence_weights = compute_prevalence_weights(prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome]) initial_factors_estimator = CMRelevantFactorsEstimator(tmle.collaborative_strategy; train_validation_indices=train_validation_indices, models=tmle.models, From 61fba8d844dc13c8146ce5b2d5597184dc95edce Mon Sep 17 00:00:00 2001 From: roskamsh Date: Tue, 7 Apr 2026 15:08:01 +0100 Subject: [PATCH 02/19] update check to look for prevalence file --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index c54d4280..07ddd96b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,9 +253,9 @@ with_encoder(model; encoder=ContinuousEncoder(drop_last=true, one_hot_ordered_fa Evaluate if the dataset is suitable for the estimand Ψ. """ -function check_inputs(Ψ, dataset, prevalence) +function check_inputs(Ψ, dataset, prevalence, prevalence_file) check_treatment_levels(Ψ, dataset) - !isnothing(prevalence) && ccw_check(dataset, Ψ.outcome) + (!isnothing(prevalence) || !isnothing(prevalence_file)) && ccw_check(dataset, Ψ.outcome) end """ From bc4fe0c54ca931e1f7916299805f639e0975a5e0 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Tue, 7 Apr 2026 15:09:28 +0100 Subject: [PATCH 03/19] name columns appropriately --- src/counterfactual_mean_based/estimators.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index ae66e350..820a59f2 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -110,8 +110,9 @@ function Tmle(; end function load_prevalence_map(prevalence_file::AbstractString) - df = CSV.read(prevalence_file, DataFrame; delim='\t') + df = CSV.read(prevalence_file, DataFrame; delim='\t',header=F) isempty(df) && return nothing + rename!(df, [:trait, :prevalence]) return Dict(Symbol(df.trait[i]) => Float64(df.prevalence[i]) for i in 1:nrow(df)) end From 29850ab85ec12a09106775d5b41e48a436efc549 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:20:41 +0100 Subject: [PATCH 04/19] avoid using CSV ad TargeneCore --- src/counterfactual_mean_based/estimators.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 820a59f2..1a79abd1 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -110,10 +110,15 @@ function Tmle(; end function load_prevalence_map(prevalence_file::AbstractString) - df = CSV.read(prevalence_file, DataFrame; delim='\t',header=F) - isempty(df) && return nothing - rename!(df, [:trait, :prevalence]) - return Dict(Symbol(df.trait[i]) => Float64(df.prevalence[i]) for i in 1:nrow(df)) + lines = readlines(prevalence_file) + isempty(lines) && return nothing + + d = Dict{Symbol, Float64}() + for line in lines + trait, prev = split(line, '\t') + d[Symbol(trait)] = parse(Float64, prev) + end + return d end function prevalence_for_estimand(Ψ, prevalence) @@ -122,7 +127,7 @@ function prevalence_for_estimand(Ψ, prevalence) if prevalence isa Float64 return prevalence elseif prevalence isa Dict{Symbol, Float64} - return prevalence[get_outcome(Ψ)] + return prevalence[Ψ.outcome] else @error("Unsupported prevalence type: $(typeof(prevalence))") end From 739d4629fba0d043da710f4a7910502452fcf155 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:21:04 +0100 Subject: [PATCH 05/19] add prevalence_file default --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 07ddd96b..ccec32e4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,7 +253,7 @@ with_encoder(model; encoder=ContinuousEncoder(drop_last=true, one_hot_ordered_fa Evaluate if the dataset is suitable for the estimand Ψ. """ -function check_inputs(Ψ, dataset, prevalence, prevalence_file) +function check_inputs(Ψ, dataset, prevalence; prevalence_file=nothing) check_treatment_levels(Ψ, dataset) (!isnothing(prevalence) || !isnothing(prevalence_file)) && ccw_check(dataset, Ψ.outcome) end From 00e8d0b28f5a86268641edcd4e0e7145b5ae9304 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:22:34 +0100 Subject: [PATCH 06/19] add test for multi-trait CCW adjustment using prev file --- .../case_control_weighted_tmle.jl | 130 ++++++++++++++++-- 1 file changed, 122 insertions(+), 8 deletions(-) diff --git a/test/counterfactual_mean_based/case_control_weighted_tmle.jl b/test/counterfactual_mean_based/case_control_weighted_tmle.jl index f64ea7d8..242e14d7 100644 --- a/test/counterfactual_mean_based/case_control_weighted_tmle.jl +++ b/test/counterfactual_mean_based/case_control_weighted_tmle.jl @@ -9,6 +9,9 @@ using Distributions using MLJBase using MLJLinearModels using Statistics +using CSV + +DATADIR = joinpath(pkgdir(TMLE), "test", "data") # Helper: Draw a case-control sample with specified prevalence function subsample_case_control( @@ -20,22 +23,24 @@ function subsample_case_control( ) n_case = round(Int, prevalence * n) n_ctl = n - n_case - Ycol = pop[!, outcome_col] - cases = findall(Ycol .== 1) - controls = findall(Ycol .== 0) + + ycol = pop[!, outcome_col] + cases = findall(ycol .== 1) + controls = findall(ycol .== 0) if length(cases) < n_case - throw(ArgumentError("Not enough cases: have $(length(cases)), need $n_case")) + throw(ArgumentError("Not enough cases for $outcome_col: have $(length(cases)), need $n_case")) end if length(controls) < n_ctl - throw(ArgumentError("Not enough controls: have $(length(controls)), need $n_ctl")) + throw(ArgumentError("Not enough controls for $outcome_col: have $(length(controls)), need $n_ctl")) end ix_case = shuffle(rng, cases)[1:n_case] ix_ctl = shuffle(rng, controls)[1:n_ctl] - ix = vcat(ix_case, ix_ctl) - ix = shuffle(rng, ix) + ix = shuffle(rng, vcat(ix_case, ix_ctl)) + sub_pop = pop[ix, :] sub_pop.A = categorical(Bool.(sub_pop.A)) - sub_pop.Y = categorical(Bool.(sub_pop.Y)) + sub_pop[!, outcome_col] = categorical(Bool.(sub_pop[!, outcome_col])) + return sub_pop end @@ -44,6 +49,21 @@ function pY_given_A_W(A, W; α=-3, β=log(2), γ=log(1.5)) return 1 ./ (1 .+ exp.(-ηY)) end +function make_population(Npop::Int) + W = rand(Bernoulli(0.5), Npop) + ηA = -0.2 .+ 0.8 .* W + pA = 1 ./ (1 .+ exp.(-ηA)) + A = rand.(Bernoulli.(pA)) + + pY1 = pY_given_A_W(A, W; α=-3.0, β=log(2.0), γ=log(1.5)) + pY2 = pY_given_A_W(A, W; α=-2.2, β=log(1.4), γ=log(1.8)) + + Y1 = rand.(Bernoulli.(pY1)) + Y2 = rand.(Bernoulli.(pY2)) + + return DataFrame(W=W, A=A, Y1=Y1, Y2=Y2) +end + @testset "CCW-TMLE bootstrapping test" begin Random.seed!(42) Npop = 2_000_000 @@ -104,5 +124,99 @@ end @test mean(ccw_coverage) > 0.80 end +@testset "Test multi-trait CCW run with prevalence TSV file" begin + Random.seed!(42) + pop = make_population(200_000) + + # For running full model, copy pop + pop_copy = deepcopy(pop) + pop_copy.A = categorical(pop_copy.A) + pop_copy.Y1 = categorical(pop_copy.Y1) + pop_copy.Y2 = categorical(pop_copy.Y2) + + # Define prevalences + prevalence_file = joinpath(DATADIR, "prevalences.tsv") + prevalence_df = CSV.read(prevalence_file, DataFrame, header=false, delim="\t") + rename!(prevalence_df, [:trait, :prevalence]) + + prevalence_by_trait = Dict( + Symbol(row.trait) => Float64(row.prevalence) + for row in eachrow(prevalence_df) + ) + + # First, compute ground truth, with params used to generate Y1 and Y2 above + # This is used to compare true values in the loop below + trait_params = Dict( + :Y1 => (α = -3.0, β = log(2.0), γ = log(1.5)), + :Y2 => (α = -2.2, β = log(1.4), γ = log(1.8)), + ) + true_rd_by_trait = Dict{Symbol, Float64}() + for trait in [:Y1, :Y2] + p = trait_params[trait] + true_rd_by_trait[trait] = mean( + pY_given_A_W(1, pop.W; α=p.α, β=p.β, γ=p.γ) .- + pY_given_A_W(0, pop.W; α=p.α, β=p.β, γ=p.γ) + ) + end + + # Now run bootstrap across both traits + traits = [:Y1, :Y2] + n_sample = 10_000 + B = 10 + + for trait in traits + trait_prev = prevalence_by_trait[trait] + true_rd_trait = true_rd_by_trait[trait] + + Ψ = ATE( + outcome = trait, + treatment_values = (A = (case = true, control = false),), + treatment_confounders = (A = [:W],) + ) + + tmle_std = Tmle(weighted=false) + tmle_ccw = Tmle(prevalence=trait_prev, weighted=false) + tmle_ccw_prev_file = Tmle(prevalence_file=prevalence_file, weighted=false) + + # First, check on full population to see if prev_file and prev give the same result + ccw_full_result, _ = tmle_ccw(Ψ, pop_copy; verbosity=0) + prev_file_full_result, _ = tmle_ccw_prev_file(Ψ, pop_copy; verbosity=0) + @test isapprox(ccw_full_result.estimate, prev_file_full_result.estimate; atol=1e-3) + + std_estimates = Float64[] + ccw_estimates = Float64[] + prev_file_estimates = Float64[] + + for b in 1:B + sample = subsample_case_control( + pop, + n_sample, + trait_prev; + outcome_col = trait, + rng = Random.MersenneTwister(1000 + b), + ) + + std_result, _ = tmle_std(Ψ, sample; verbosity=0) + ccw_result, _ = tmle_ccw(Ψ, sample; verbosity=0) + + # This is the extra check: prevalence loaded from the TSV file directly + prev_file_result, _ = tmle_ccw_prev_file(Ψ, sample, verbosity=0) + + push!(std_estimates, std_result.estimate) + push!(ccw_estimates, ccw_result.estimate) + push!(prev_file_estimates, prev_file_result.estimate) + + @test isfinite(std_result.estimate) + @test isfinite(ccw_result.estimate) + @test isfinite(prev_file_result.estimate) + end + + # Check prev_file estimates and CCW are approx equal, and bias is reduced compard to std + @test isapprox(mean(prev_file_estimates), mean(ccw_estimates); atol=1e-3) + @test abs(mean(ccw_estimates) - true_rd_trait) < abs(mean(std_estimates) - true_rd_trait) + + end +end + end true \ No newline at end of file From 224316b4aed793d394906be1669461c895f8bbc5 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:22:57 +0100 Subject: [PATCH 07/19] add check_inputs for prevalence file --- test/utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index b062eab3..4aae73b6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,6 +8,8 @@ using MLJLinearModels using MLJModels using DataFrames +TEST_DIR = joinpath(dirname(dirname(pathof(TMLE))), "test") + @testset "Test expected_value" begin n = 100 X = MLJBase.table(rand(n, 3)) @@ -221,6 +223,7 @@ end W = rand(7) ) @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence) + @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, nothing, joinpath(TEST_DIR,"data","prevalences.tsv")) ## The number of controls must be larger than the number of cases dataset = DataFrame( Y = categorical([1, 0, 1, 0, 1, 1, 0]), From 277dcc03b8d8da1bf6432f572bc12ce702ca10c5 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:23:11 +0100 Subject: [PATCH 08/19] add prevalences file --- test/data/prevalences.tsv | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 test/data/prevalences.tsv diff --git a/test/data/prevalences.tsv b/test/data/prevalences.tsv new file mode 100644 index 00000000..1a529305 --- /dev/null +++ b/test/data/prevalences.tsv @@ -0,0 +1,2 @@ +Y1 0.088 +Y2 0.016 From d9c44e3b943d6e5f325661d3c60f479246749fa1 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:33:40 +0100 Subject: [PATCH 09/19] update prevalence_file to be keyword arg --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 4aae73b6..38cff46f 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -223,7 +223,7 @@ end W = rand(7) ) @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence) - @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, nothing, joinpath(TEST_DIR,"data","prevalences.tsv")) + @test_throws ArgumentError("Outcome column must be binary when prevalence file is specified.") TMLE.check_inputs(Ψ, dataset, nothing; joinpath(TEST_DIR,"data","prevalences.tsv")) ## The number of controls must be larger than the number of cases dataset = DataFrame( Y = categorical([1, 0, 1, 0, 1, 1, 0]), From 4ad253efd4ad032c846fb943bb901fa4fbc4f711 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:39:47 +0100 Subject: [PATCH 10/19] add prev_file keyword to function call --- test/utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 38cff46f..c34d3a02 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -222,8 +222,9 @@ end T = categorical([1, 1, 0, 1, 0, 2, 2]), W = rand(7) ) + prevalence_file = joinpath(TEST_DIR,"data","prevalences.tsv") @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence) - @test_throws ArgumentError("Outcome column must be binary when prevalence file is specified.") TMLE.check_inputs(Ψ, dataset, nothing; joinpath(TEST_DIR,"data","prevalences.tsv")) + @test_throws ArgumentError("Outcome column must be binary when prevalence file is specified.") TMLE.check_inputs(Ψ, dataset, nothing; prevalence_file=prevalence_file) ## The number of controls must be larger than the number of cases dataset = DataFrame( Y = categorical([1, 0, 1, 0, 1, 1, 0]), From cf415464d4dfc7f18e0a1e6c905a84ed563af433 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:55:27 +0100 Subject: [PATCH 11/19] remove prevalence_file and just pass in either a single value or a Dict --- src/counterfactual_mean_based/estimators.jl | 30 ++++----------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 1a79abd1..590c0f34 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -12,7 +12,6 @@ mutable struct Tmle <: Estimator max_iter::Int machine_cache::Bool prevalence::Union{Nothing, Float64, Dict{Symbol, Float64}} - prevalence_file::Union{Nothing, String} function Tmle( models, resampling, @@ -22,8 +21,7 @@ mutable struct Tmle <: Estimator tol, max_iter, machine_cache, - prevalence, - prevalence_file + prevalence ) if resampling === nothing && collaborative_strategy !== nothing @warn("Collaborative TMLE requires a resampling strategy but none was provided. Using the default resampling strategy.") @@ -37,8 +35,7 @@ mutable struct Tmle <: Estimator weighted, tol, max_iter, machine_cache, - prevalence, - prevalence_file + prevalence ) end end @@ -62,8 +59,7 @@ been show to be more robust to positivity violation in practice. - tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`. - max_iter (default: 1): Maximum number of iterations for the TMLE algorithm. - machine_cache (default: false): Whether MLJ.machine created during estimation should cache data. -- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. -- prevalence_file (default: nothing): A file including the prevalence for each trait can also be provided. This must be a TSV with the first column including the traits, and the second column including the prevalences. +- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. This can either be a single value to be uniformly applied, or a Dict that maps each trait to a prevalence value. # Run Argument @@ -90,12 +86,9 @@ function Tmle(; tol=nothing, max_iter=1, machine_cache=false, - prevalence=nothing, - prevalence_file=nothing + prevalence=nothing ) - prevalence_parsed = prevalence_file !== nothing ? load_prevalence_map(prevalence_file) : prevalence - Tmle( models, resampling, @@ -104,23 +97,10 @@ function Tmle(; weighted, tol, max_iter, machine_cache, - prevalence_parsed, - prevalence_file + prevalence ) end -function load_prevalence_map(prevalence_file::AbstractString) - lines = readlines(prevalence_file) - isempty(lines) && return nothing - - d = Dict{Symbol, Float64}() - for line in lines - trait, prev = split(line, '\t') - d[Symbol(trait)] = parse(Float64, prev) - end - return d -end - function prevalence_for_estimand(Ψ, prevalence) prevalence === nothing && return nothing From d5a38e68d0381cd7c7ad4c9538f197427d983615 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:55:37 +0100 Subject: [PATCH 12/19] remove prevalence_file --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ccec32e4..c54d4280 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,9 +253,9 @@ with_encoder(model; encoder=ContinuousEncoder(drop_last=true, one_hot_ordered_fa Evaluate if the dataset is suitable for the estimand Ψ. """ -function check_inputs(Ψ, dataset, prevalence; prevalence_file=nothing) +function check_inputs(Ψ, dataset, prevalence) check_treatment_levels(Ψ, dataset) - (!isnothing(prevalence) || !isnothing(prevalence_file)) && ccw_check(dataset, Ψ.outcome) + !isnothing(prevalence) && ccw_check(dataset, Ψ.outcome) end """ From 79974411da295eb386706c29ea33fdfbf7e21757 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:55:52 +0100 Subject: [PATCH 13/19] update with prevalence_dict functionality --- .../case_control_weighted_tmle.jl | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/test/counterfactual_mean_based/case_control_weighted_tmle.jl b/test/counterfactual_mean_based/case_control_weighted_tmle.jl index 242e14d7..10e6439c 100644 --- a/test/counterfactual_mean_based/case_control_weighted_tmle.jl +++ b/test/counterfactual_mean_based/case_control_weighted_tmle.jl @@ -124,7 +124,7 @@ end @test mean(ccw_coverage) > 0.80 end -@testset "Test multi-trait CCW run with prevalence TSV file" begin +@testset "Test multi-trait CCW run with prevalence dictionary" begin Random.seed!(42) pop = make_population(200_000) @@ -134,22 +134,18 @@ end pop_copy.Y1 = categorical(pop_copy.Y1) pop_copy.Y2 = categorical(pop_copy.Y2) - # Define prevalences - prevalence_file = joinpath(DATADIR, "prevalences.tsv") - prevalence_df = CSV.read(prevalence_file, DataFrame, header=false, delim="\t") - rename!(prevalence_df, [:trait, :prevalence]) - + # True prevalences computed from the population prevalence_by_trait = Dict( - Symbol(row.trait) => Float64(row.prevalence) - for row in eachrow(prevalence_df) + :Y1 => mean(pop.Y1), + :Y2 => mean(pop.Y2), ) - # First, compute ground truth, with params used to generate Y1 and Y2 above - # This is used to compare true values in the loop below + # Ground truth for each trait, using the parameters that generated them trait_params = Dict( - :Y1 => (α = -3.0, β = log(2.0), γ = log(1.5)), - :Y2 => (α = -2.2, β = log(1.4), γ = log(1.8)), + :Y1 => (α = -3.0, β = log(2.0), γ = log(1.5)), + :Y2 => (α = -2.2, β = log(1.4), γ = log(1.8)), ) + true_rd_by_trait = Dict{Symbol, Float64}() for trait in [:Y1, :Y2] p = trait_params[trait] @@ -159,7 +155,6 @@ end ) end - # Now run bootstrap across both traits traits = [:Y1, :Y2] n_sample = 10_000 B = 10 @@ -176,16 +171,16 @@ end tmle_std = Tmle(weighted=false) tmle_ccw = Tmle(prevalence=trait_prev, weighted=false) - tmle_ccw_prev_file = Tmle(prevalence_file=prevalence_file, weighted=false) + tmle_ccw_prev_dict = Tmle(prevalence=prevalence_by_trait, weighted=false) - # First, check on full population to see if prev_file and prev give the same result - ccw_full_result, _ = tmle_ccw(Ψ, pop_copy; verbosity=0) - prev_file_full_result, _ = tmle_ccw_prev_file(Ψ, pop_copy; verbosity=0) - @test isapprox(ccw_full_result.estimate, prev_file_full_result.estimate; atol=1e-3) + # Check on full population: dict-based prevalence vs scalar prevalence + ccw_full_result, _ = tmle_ccw(Ψ, pop_copy; verbosity=0) + prev_dict_full_result, _ = tmle_ccw_prev_dict(Ψ, pop_copy; verbosity=0) + @test isapprox(ccw_full_result.estimate, prev_dict_full_result.estimate; atol=1e-3) std_estimates = Float64[] ccw_estimates = Float64[] - prev_file_estimates = Float64[] + prev_dict_estimates = Float64[] for b in 1:B sample = subsample_case_control( @@ -199,22 +194,23 @@ end std_result, _ = tmle_std(Ψ, sample; verbosity=0) ccw_result, _ = tmle_ccw(Ψ, sample; verbosity=0) - # This is the extra check: prevalence loaded from the TSV file directly - prev_file_result, _ = tmle_ccw_prev_file(Ψ, sample, verbosity=0) + # Dictionary-based prevalence run + prev_dict_result, _ = tmle_ccw_prev_dict(Ψ, sample; verbosity=0) push!(std_estimates, std_result.estimate) push!(ccw_estimates, ccw_result.estimate) - push!(prev_file_estimates, prev_file_result.estimate) + push!(prev_dict_estimates, prev_dict_result.estimate) @test isfinite(std_result.estimate) @test isfinite(ccw_result.estimate) - @test isfinite(prev_file_result.estimate) + @test isfinite(prev_dict_result.estimate) end - # Check prev_file estimates and CCW are approx equal, and bias is reduced compard to std - @test isapprox(mean(prev_file_estimates), mean(ccw_estimates); atol=1e-3) - @test abs(mean(ccw_estimates) - true_rd_trait) < abs(mean(std_estimates) - true_rd_trait) + # Dict-based prevalence and scalar prevalence should agree closely + @test isapprox(mean(prev_dict_estimates), mean(ccw_estimates); atol=1e-3) + # CCW should be closer to the truth than standard TMLE + @test abs(mean(ccw_estimates) - true_rd_trait) < abs(mean(std_estimates) - true_rd_trait) end end From b212774367b799029a10799c24be742df2f9f54a Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 08:56:09 +0100 Subject: [PATCH 14/19] update check_inputs --- test/utils.jl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index c34d3a02..1c123075 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -211,27 +211,36 @@ end # Check with prevalence prevalence = 0.1 + prevalence_dict = Dict(:Y => 0.1) + Ψ = CM( outcome = :Y, treatment_values = (T=1,), treatment_confounders = [:W] ) + ## The outcome must be binary dataset = DataFrame( - Y = categorical([1, 0, 1, 0, 1, 1, 2]), + Y = categorical([1, 0, 1, 0, 1, 1, 2]), # not binary T = categorical([1, 1, 0, 1, 0, 2, 2]), W = rand(7) ) - prevalence_file = joinpath(TEST_DIR,"data","prevalences.tsv") + @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence) - @test_throws ArgumentError("Outcome column must be binary when prevalence file is specified.") TMLE.check_inputs(Ψ, dataset, nothing; prevalence_file=prevalence_file) + + # Same check but using dict instead of scalar + @test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence_dict) + ## The number of controls must be larger than the number of cases dataset = DataFrame( - Y = categorical([1, 0, 1, 0, 1, 1, 0]), + Y = categorical([1, 0, 1, 0, 1, 1, 0]), # more cases than controls T = categorical([1, 1, 0, 1, 0, 2, 2]), W = rand(7) ) + @test_throws ArgumentError("The dataset must contain more controls (0) than cases (1) when prevalence is provided.") TMLE.check_inputs(Ψ, dataset, prevalence) + # Same check with dict + @test_throws ArgumentError("The dataset must contain more controls (0) than cases (1) when prevalence is provided.") TMLE.check_inputs(Ψ, dataset, prevalence_dict) end @testset "Test get_fluctuation_dataset" begin From 67af0c9e03e07dcb06c2bb16b25502de0d19f037 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 09:01:45 +0100 Subject: [PATCH 15/19] remove old file --- test/data/prevalences.tsv | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 test/data/prevalences.tsv diff --git a/test/data/prevalences.tsv b/test/data/prevalences.tsv deleted file mode 100644 index 1a529305..00000000 --- a/test/data/prevalences.tsv +++ /dev/null @@ -1,2 +0,0 @@ -Y1 0.088 -Y2 0.016 From 1a848a0fdc63871819503168a8c22d1146872fef Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 09:02:02 +0100 Subject: [PATCH 16/19] remove datadir as not needed now --- test/counterfactual_mean_based/case_control_weighted_tmle.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/counterfactual_mean_based/case_control_weighted_tmle.jl b/test/counterfactual_mean_based/case_control_weighted_tmle.jl index 10e6439c..adc0d29c 100644 --- a/test/counterfactual_mean_based/case_control_weighted_tmle.jl +++ b/test/counterfactual_mean_based/case_control_weighted_tmle.jl @@ -11,8 +11,6 @@ using MLJLinearModels using Statistics using CSV -DATADIR = joinpath(pkgdir(TMLE), "test", "data") - # Helper: Draw a case-control sample with specified prevalence function subsample_case_control( pop::DataFrame, @@ -208,8 +206,7 @@ end # Dict-based prevalence and scalar prevalence should agree closely @test isapprox(mean(prev_dict_estimates), mean(ccw_estimates); atol=1e-3) - - # CCW should be closer to the truth than standard TMLE + # Bias should be reduced with correct prevalence specified @test abs(mean(ccw_estimates) - true_rd_trait) < abs(mean(std_estimates) - true_rd_trait) end end From 6e23b39ba8c1702aed9595490eedec8f8ce9d0fe Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 09:02:14 +0100 Subject: [PATCH 17/19] remove testdir as not needed now --- test/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 1c123075..55f5147e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,8 +8,6 @@ using MLJLinearModels using MLJModels using DataFrames -TEST_DIR = joinpath(dirname(dirname(pathof(TMLE))), "test") - @testset "Test expected_value" begin n = 100 X = MLJBase.table(rand(n, 3)) From 9b929b1c164be60890189ac4a9688d37cf516fa7 Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 09:06:17 +0100 Subject: [PATCH 18/19] up TMLE --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 10c8a318..0461755b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,6 +21,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 0dce2109c011b85b15ada4c241d97aecf46c1b5a Mon Sep 17 00:00:00 2001 From: roskamsh Date: Wed, 8 Apr 2026 09:30:56 +0100 Subject: [PATCH 19/19] remove updates --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d801ef3d..3fdb3e65 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TMLE" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] -version = "0.20.3" +version = "0.20.4" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"