Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.20.3"
version = "0.20.4"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
Expand Down
27 changes: 21 additions & 6 deletions src/counterfactual_mean_based/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mutable struct Tmle <: Estimator
tol::Union{Float64, Nothing}
max_iter::Int
machine_cache::Bool
prevalence::Union{Nothing, Float64}
prevalence::Union{Nothing, Float64, Dict{Symbol, Float64}}
function Tmle(
models,
resampling,
Expand Down Expand Up @@ -59,7 +59,7 @@ been show to be more robust to positivity violation in practice.
- tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`.
- max_iter (default: 1): Maximum number of iterations for the TMLE algorithm.
- machine_cache (default: false): Whether MLJ.machine created during estimation should cache data.
- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population.
- prevalence (default: nothing): If provided, the prevalence weights will be used to weight the observations to match the true prevalence of the source population. This can either be a single value to be uniformly applied, or a Dict that maps each trait to a prevalence value.

# Run Argument

Expand Down Expand Up @@ -88,6 +88,7 @@ function Tmle(;
machine_cache=false,
prevalence=nothing
)

Tmle(
models,
resampling,
Expand All @@ -100,24 +101,38 @@ function Tmle(;
)
end

function prevalence_for_estimand(Ψ, prevalence)
prevalence === nothing && return nothing

if prevalence isa Float64
return prevalence
elseif prevalence isa Dict{Symbol, Float64}
return prevalence[Ψ.outcome]
else
@error("Unsupported prevalence type: $(typeof(prevalence))")
end
end

function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1())
prevalence = prevalence_for_estimand(Ψ, tmle.prevalence)

# Check if the inputs are suitable for the specified estimand
check_inputs(Ψ, dataset, tmle.prevalence)
check_inputs(Ψ, dataset, prevalence)
# Make train-validation pairs
train_validation_indices = get_train_validation_indices(tmle.resampling, Ψ, dataset)
# Initial fit of the SCM's relevant factors
relevant_factors = get_relevant_factors(Ψ, collaborative_strategy=tmle.collaborative_strategy)
fluctuation_dataset = get_fluctuation_dataset(dataset, relevant_factors;
prevalence=tmle.prevalence,
prevalence=prevalence,
verbosity=verbosity
)

initial_factors_dataset = choose_initial_dataset(dataset, fluctuation_dataset;
train_validation_indices=train_validation_indices,
prevalence=tmle.prevalence
prevalence=prevalence
)

prevalence_weights = compute_prevalence_weights(tmle.prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome])
prevalence_weights = compute_prevalence_weights(prevalence, initial_factors_dataset[!, relevant_factors.outcome_mean.outcome])
initial_factors_estimator = CMRelevantFactorsEstimator(tmle.collaborative_strategy;
train_validation_indices=train_validation_indices,
models=tmle.models,
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
123 changes: 115 additions & 8 deletions test/counterfactual_mean_based/case_control_weighted_tmle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Distributions
using MLJBase
using MLJLinearModels
using Statistics
using CSV

# Helper: Draw a case-control sample with specified prevalence
function subsample_case_control(
Expand All @@ -20,22 +21,24 @@ function subsample_case_control(
)
n_case = round(Int, prevalence * n)
n_ctl = n - n_case
Ycol = pop[!, outcome_col]
cases = findall(Ycol .== 1)
controls = findall(Ycol .== 0)

ycol = pop[!, outcome_col]
cases = findall(ycol .== 1)
controls = findall(ycol .== 0)
if length(cases) < n_case
throw(ArgumentError("Not enough cases: have $(length(cases)), need $n_case"))
throw(ArgumentError("Not enough cases for $outcome_col: have $(length(cases)), need $n_case"))
end
if length(controls) < n_ctl
throw(ArgumentError("Not enough controls: have $(length(controls)), need $n_ctl"))
throw(ArgumentError("Not enough controls for $outcome_col: have $(length(controls)), need $n_ctl"))
end
ix_case = shuffle(rng, cases)[1:n_case]
ix_ctl = shuffle(rng, controls)[1:n_ctl]
ix = vcat(ix_case, ix_ctl)
ix = shuffle(rng, ix)
ix = shuffle(rng, vcat(ix_case, ix_ctl))

sub_pop = pop[ix, :]
sub_pop.A = categorical(Bool.(sub_pop.A))
sub_pop.Y = categorical(Bool.(sub_pop.Y))
sub_pop[!, outcome_col] = categorical(Bool.(sub_pop[!, outcome_col]))

return sub_pop
end

Expand All @@ -44,6 +47,21 @@ function pY_given_A_W(A, W; α=-3, β=log(2), γ=log(1.5))
return 1 ./ (1 .+ exp.(-ηY))
end

function make_population(Npop::Int)
W = rand(Bernoulli(0.5), Npop)
ηA = -0.2 .+ 0.8 .* W
pA = 1 ./ (1 .+ exp.(-ηA))
A = rand.(Bernoulli.(pA))

pY1 = pY_given_A_W(A, W; α=-3.0, β=log(2.0), γ=log(1.5))
pY2 = pY_given_A_W(A, W; α=-2.2, β=log(1.4), γ=log(1.8))

Y1 = rand.(Bernoulli.(pY1))
Y2 = rand.(Bernoulli.(pY2))

return DataFrame(W=W, A=A, Y1=Y1, Y2=Y2)
end

@testset "CCW-TMLE bootstrapping test" begin
Random.seed!(42)
Npop = 2_000_000
Expand Down Expand Up @@ -104,5 +122,94 @@ end
@test mean(ccw_coverage) > 0.80
end

@testset "Test multi-trait CCW run with prevalence dictionary" begin
Random.seed!(42)
pop = make_population(200_000)

# For running full model, copy pop
pop_copy = deepcopy(pop)
pop_copy.A = categorical(pop_copy.A)
pop_copy.Y1 = categorical(pop_copy.Y1)
pop_copy.Y2 = categorical(pop_copy.Y2)

# True prevalences computed from the population
prevalence_by_trait = Dict(
:Y1 => mean(pop.Y1),
:Y2 => mean(pop.Y2),
)

# Ground truth for each trait, using the parameters that generated them
trait_params = Dict(
:Y1 => (α = -3.0, β = log(2.0), γ = log(1.5)),
:Y2 => (α = -2.2, β = log(1.4), γ = log(1.8)),
)

true_rd_by_trait = Dict{Symbol, Float64}()
for trait in [:Y1, :Y2]
p = trait_params[trait]
true_rd_by_trait[trait] = mean(
pY_given_A_W(1, pop.W; α=p.α, β=p.β, γ=p.γ) .-
pY_given_A_W(0, pop.W; α=p.α, β=p.β, γ=p.γ)
)
end

traits = [:Y1, :Y2]
n_sample = 10_000
B = 10

for trait in traits
trait_prev = prevalence_by_trait[trait]
true_rd_trait = true_rd_by_trait[trait]

Ψ = ATE(
outcome = trait,
treatment_values = (A = (case = true, control = false),),
treatment_confounders = (A = [:W],)
)

tmle_std = Tmle(weighted=false)
tmle_ccw = Tmle(prevalence=trait_prev, weighted=false)
tmle_ccw_prev_dict = Tmle(prevalence=prevalence_by_trait, weighted=false)

# Check on full population: dict-based prevalence vs scalar prevalence
ccw_full_result, _ = tmle_ccw(Ψ, pop_copy; verbosity=0)
prev_dict_full_result, _ = tmle_ccw_prev_dict(Ψ, pop_copy; verbosity=0)
@test isapprox(ccw_full_result.estimate, prev_dict_full_result.estimate; atol=1e-3)

std_estimates = Float64[]
ccw_estimates = Float64[]
prev_dict_estimates = Float64[]

for b in 1:B
sample = subsample_case_control(
pop,
n_sample,
trait_prev;
outcome_col = trait,
rng = Random.MersenneTwister(1000 + b),
)

std_result, _ = tmle_std(Ψ, sample; verbosity=0)
ccw_result, _ = tmle_ccw(Ψ, sample; verbosity=0)

# Dictionary-based prevalence run
prev_dict_result, _ = tmle_ccw_prev_dict(Ψ, sample; verbosity=0)

push!(std_estimates, std_result.estimate)
push!(ccw_estimates, ccw_result.estimate)
push!(prev_dict_estimates, prev_dict_result.estimate)

@test isfinite(std_result.estimate)
@test isfinite(ccw_result.estimate)
@test isfinite(prev_dict_result.estimate)
end

# Dict-based prevalence and scalar prevalence should agree closely
@test isapprox(mean(prev_dict_estimates), mean(ccw_estimates); atol=1e-3)
# Bias should be reduced with correct prevalence specified
@test abs(mean(ccw_estimates) - true_rd_trait) < abs(mean(std_estimates) - true_rd_trait)
end
end

end
true
15 changes: 13 additions & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,25 +209,36 @@ end

# Check with prevalence
prevalence = 0.1
prevalence_dict = Dict(:Y => 0.1)

Ψ = CM(
outcome = :Y,
treatment_values = (T=1,),
treatment_confounders = [:W]
)

## The outcome must be binary
dataset = DataFrame(
Y = categorical([1, 0, 1, 0, 1, 1, 2]),
Y = categorical([1, 0, 1, 0, 1, 1, 2]), # not binary
T = categorical([1, 1, 0, 1, 0, 2, 2]),
W = rand(7)
)

@test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence)

# Same check but using dict instead of scalar
@test_throws ArgumentError("Outcome column must be binary when prevalence is specified.") TMLE.check_inputs(Ψ, dataset, prevalence_dict)

## The number of controls must be larger than the number of cases
dataset = DataFrame(
Y = categorical([1, 0, 1, 0, 1, 1, 0]),
Y = categorical([1, 0, 1, 0, 1, 1, 0]), # more cases than controls
T = categorical([1, 1, 0, 1, 0, 2, 2]),
W = rand(7)
)

@test_throws ArgumentError("The dataset must contain more controls (0) than cases (1) when prevalence is provided.") TMLE.check_inputs(Ψ, dataset, prevalence)
# Same check with dict
@test_throws ArgumentError("The dataset must contain more controls (0) than cases (1) when prevalence is provided.") TMLE.check_inputs(Ψ, dataset, prevalence_dict)
end

@testset "Test get_fluctuation_dataset" begin
Expand Down
Loading