diff --git a/Project.toml b/Project.toml index d801ef3d..1e5355ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TMLE" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] -version = "0.20.3" +version = "0.21.0" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/counterfactual_mean_based/clever_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl index 13e9880a..04dff49f 100644 --- a/src/counterfactual_mean_based/clever_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -1,12 +1,17 @@ """ - data_adaptive_ps_lower_bound(Ψ::StatisticalCMCompositeEstimand) + data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) -This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) -but the study does not show strictly better behaviour of the strategy so not a default for now. +Data-adaptive propensity score truncation level from Gruber et al. (2022): +"Data-Adaptive Selection of the Propensity Score Truncation Level for +Inverse-Probability–Weighted and Targeted Maximum Likelihood Estimators +of Marginal Point Treatment Effects" (doi:10.1093/aje/kwac087). + +This is the default when `ps_lowerbound=nothing`. Here a maximum lower bound +is applied to prevent extreme truncation in small samples. """ data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = - min(5 / (√(n)*log(n/5)), max_lb) + min(5 / (√(n)*log(n)), max_lb) ps_lower_bound(n::Int, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(n; max_lb=max_lb) ps_lower_bound(n::Int, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) @@ -18,12 +23,14 @@ function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat) end end -function balancing_weights(G, dataset; ps_lowerbound=1e-8) - jointlikelihood = ones(nrows(dataset)) +function balancing_weights(G, dataset; ps_lowerbound=nothing) + n = nrows(dataset) + jointlikelihood = ones(n) for Gᵢ ∈ G.components jointlikelihood .*= likelihood(Gᵢ, dataset) end - truncate!(jointlikelihood, ps_lowerbound) + actual_lowerbound = ps_lower_bound(n, ps_lowerbound) + truncate!(jointlikelihood, actual_lowerbound) return 1. ./ jointlikelihood end @@ -32,39 +39,44 @@ end Ψ::StatisticalCMCompositeEstimand, Gs::Tuple{Vararg{ConditionalDistributionEstimate}}, dataset; - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted_fluctuation=false ) -Computes the clever covariate and weights that are used to fluctuate the initial Q. +Computes the clever covariate matrix, weights, and signs used to fluctuate the initial Q. + +Returns a tuple `(H, w, signs)` where: +- `H` is an `n × K` matrix, one column per counterfactual in `indicator_fns(Ψ)`. +- `w` is a length-`n` weight vector. +- `signs` is a length-`K` vector of signs from `indicator_fns(Ψ)`. if `weighted_fluctuation = false`: -- ``clever_covariate(t, w) = \\frac{SpecialIndicator(t)}{p(t|w)}`` -- ``weight(t, w) = 1`` +- ``H_{k}(t, w) = \\frac{I_k(t)}{p(t|w)}`` +- ``w(t, w) = 1`` if `weighted_fluctuation = true`: -- ``clever_covariate(t, w) = SpecialIndicator(t)`` -- ``weight(t, w) = \\frac{1}{p(t|w)}`` +- ``H_{k}(t, w) = I_k(t)`` +- ``w(t, w) = \\frac{1}{p(t|w)}`` -where SpecialIndicator(t) is defined in `indicator_fns`. +where ``I_k(t)`` is the unsigned indicator for the k-th counterfactual. """ function clever_covariate_and_weights( Ψ::StatisticalCMCompositeEstimand, G, dataset; - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted_fluctuation=false ) - # Compute the indicator values + # Compute the indicator matrix (n×K) and signs (K,) T = selectcols(dataset, (p.estimand.outcome for p in G.components)) - indic_vals = indicator_values(indicator_fns(Ψ), T) + indic_mat, signs = indicator_matrix(indicator_fns(Ψ), T) weights = balancing_weights(G, dataset; ps_lowerbound=ps_lowerbound) if weighted_fluctuation - return indic_vals, weights + return indic_mat, weights, signs end # Vanilla unweighted fluctuation - indic_vals .*= weights - return indic_vals, ones(size(weights, 1)) + indic_mat .*= weights + return indic_mat, ones(size(weights, 1)), signs end \ No newline at end of file diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 467f647f..50b80796 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -41,7 +41,7 @@ mutable struct Tmle <: Estimator end """ - Tmle(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing, machine_cache=false) + Tmle(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, weighted=false, tol=nothing, machine_cache=false) Defines a TMLE estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -52,8 +52,9 @@ function that can be applied to estimate estimands for a dataset. - collaborative_strategy (default: nothing): A collaborative strategy to use for the estimation. Then the resampling strategy is used to evaluate the candidates. - resampling (default: `default_resampling(collaborative_strategy)`): Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla TMLE while any valid `MLJ.ResamplingStrategy` will result in CV-TMLE. -- ps_lowerbound (default: 1e-8): Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will -result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). +- ps_lowerbound (default: nothing): Lowerbound for the propensity score to avoid division by 0. The default `nothing` +uses data-adaptive truncation as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/): `5/(√n * log(n/5))`. +A fixed value can also be provided. - weighted (default: false): Whether the fluctuation model is a classig GLM or a weighted version. The weighted fluctuation has been show to be more robust to positivity violation in practice. - tol (default: nothing): Convergence threshold for the TMLE algorithm iterations. If nothing (default), 1/(sample size) will be used. See also `max_iter`. @@ -81,7 +82,7 @@ function Tmle(; models=default_models(), collaborative_strategy=nothing, resampling=default_resampling(collaborative_strategy), - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted=true, tol=nothing, max_iter=1, @@ -103,6 +104,8 @@ end function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) # Check if the inputs are suitable for the specified estimand check_inputs(Ψ, dataset, tmle.prevalence) + # Reset collaborative strategy state before building relevant factors + tmle.collaborative_strategy !== nothing && initialise!(tmle.collaborative_strategy, Ψ) # Make train-validation pairs train_validation_indices = get_train_validation_indices(tmle.resampling, Ψ, dataset) # Initial fit of the SCM's relevant factors @@ -167,7 +170,7 @@ function (tmle::Tmle)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), return TMLEstimate(Ψ, Ψ̂, σ̂, n, IC), cache end -gradient_and_estimate(::Tmle, Ψ, factors, dataset; ps_lowerbound=1e-8) = +gradient_and_estimate(::Tmle, Ψ, factors, dataset; ps_lowerbound=nothing) = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound) ##################################################################### @@ -179,10 +182,11 @@ mutable struct Ose <: Estimator resampling::Union{Nothing, ResamplingStrategy} ps_lowerbound::Union{Float64, Nothing} machine_cache::Bool + prevalence::Union{Nothing, Float64} end """ - Ose(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) + Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false, prevalence=nothing) Defines a One Step Estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -192,9 +196,11 @@ function that can be applied to estimate estimands for a dataset. - models: A Dict(variable => model, ...) where the `variables` are the outcome variables modeled by the `models`. - resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla estimation while any valid `MLJ.ResamplingStrategy` will result in CV-OSE. -- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will -result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). +- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` (default) will +result in a data adaptive definition as described in [Gruber et al. (2022)](https://pubmed.ncbi.nlm.nih.gov/35512316/). - machine_cache: Whether MLJ.machine created during estimation should cache data. +- prevalence: The prevalence of the outcome in the population. If provided, the estimator will use case-control +weighted estimation (CCW-OSE) to correct for biased sampling. # Run Argument @@ -213,22 +219,32 @@ ose = Ose() Ψ̂ₙ, cache = ose(Ψ, dataset) ``` """ -Ose(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) = - Ose(models, resampling, ps_lowerbound, machine_cache) +Ose(;models=default_models(), resampling=nothing, ps_lowerbound=nothing, machine_cache=false, prevalence=nothing) = + Ose(models, resampling, ps_lowerbound, machine_cache, prevalence) function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1, acceleration=CPU1()) # Check the estimand against the dataset - check_treatment_levels(Ψ, dataset) + check_inputs(Ψ, dataset, ose.prevalence) # Make train-validation pairs train_validation_indices = get_train_validation_indices(ose.resampling, Ψ, dataset) # Initial fit of the SCM's relevant factors initial_factors = get_relevant_factors(Ψ) - nomissing_dataset = nomissing(dataset, variables(initial_factors)) - initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset; + # Get appropriate dataset (matched controls if prevalence is set) + fluctuation_dataset = get_fluctuation_dataset(dataset, initial_factors; + prevalence=ose.prevalence, + verbosity=verbosity + ) + initial_factors_dataset = choose_initial_dataset(dataset, fluctuation_dataset; train_validation_indices=train_validation_indices, - prevalence=nothing + prevalence=ose.prevalence + ) + # Compute prevalence weights for fitting Q + prevalence_weights = compute_prevalence_weights(ose.prevalence, initial_factors_dataset[!, initial_factors.outcome_mean.outcome]) + initial_factors_estimator = CMRelevantFactorsEstimator(; + models=ose.models, + train_validation_indices=train_validation_indices, + prevalence_weights=prevalence_weights ) - initial_factors_estimator = CMRelevantFactorsEstimator(;models=ose.models, train_validation_indices=train_validation_indices) initial_factors_estimate = initial_factors_estimator( initial_factors, initial_factors_dataset; @@ -237,19 +253,25 @@ function (ose::Ose)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), v acceleration=acceleration ) # Get propensity score truncation threshold - n = nrows(nomissing_dataset) + n = nrows(fluctuation_dataset) ps_lowerbound = ps_lower_bound(n, ose.ps_lowerbound) # Gradient and estimate - IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, fluctuation_dataset; + ps_lowerbound=ps_lowerbound, prevalence_weights=prevalence_weights) σ̂ = std(IC) n = size(IC, 1) verbosity >= 1 && @info "Done." return OSEstimate(Ψ, Ψ̂, σ̂, n, IC), cache end -function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=1e-8) - IC, Ψ̂ = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound) +function gradient_and_estimate(::Ose, Ψ, factors, dataset; ps_lowerbound=nothing, prevalence_weights=nothing) + Q = factors.outcome_mean + G = factors.propensity_score + ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) + gradient_Y_X = ∇YX(Ψ, Q, G, dataset; ps_lowerbound=ps_lowerbound) + y = float(dataset[!, Q.estimand.outcome]) + IC, Ψ̂ = gradient_and_estimate(ctf_agg, gradient_Y_X, y, prevalence_weights) IC_mean = mean(IC) IC .-= IC_mean return IC, Ψ̂ + IC_mean diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index dd7c2b51..2d489b63 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -3,13 +3,13 @@ mutable struct Fluctuation <: MLJBase.Supervised initial_factors::MLCMRelevantFactors tol::Union{Nothing, Float64} max_iter::Int - ps_lowerbound::Float64 + ps_lowerbound::Union{Float64, Nothing} weighted::Bool cache::Bool prevalence_weights::Union{Nothing, Vector{Float64}} end -Fluctuation(Ψ, initial_factors; tol=nothing, max_iter=1, ps_lowerbound=1e-8, weighted=false, cache=false, prevalence_weights=nothing) = +Fluctuation(Ψ, initial_factors; tol=nothing, max_iter=1, ps_lowerbound=nothing, weighted=false, cache=false, prevalence_weights=nothing) = Fluctuation(Ψ, initial_factors, tol, max_iter, ps_lowerbound, weighted, cache, prevalence_weights) one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:MLJBase.Continuous} = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -24,11 +24,23 @@ The GLM models require inputs of the same type, which sometimes is not the case same_type_df(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = DataFrame(covariate=covariate, offset=convert(Vector{T1}, offset)) -function fluctuation_input(covariate, ŷ) - offset = compute_offset(ŷ) +function fluctuation_input(covariate::AbstractVector, ŷ) + offset = compute_offset(ŷ) return same_type_df(covariate, offset) end +function fluctuation_input(covariate::AbstractMatrix, ŷ) + offset = compute_offset(ŷ) + T = eltype(covariate) + typed_offset = eltype(offset) == T ? offset : convert(Vector{T}, offset) + K = size(covariate, 2) + df = DataFrame(:offset => typed_offset) + for k in 1:K + df[!, Symbol("H_", k)] = covariate[:, k] + end + return df +end + hasconverged(gradient, tol) = abs(mean(gradient)) < tol """ @@ -53,13 +65,13 @@ If prevalence weights are provided, they are applied to the weights and normaliz function initialize_observed_cache(model, X, y) Q⁰ = model.initial_factors.outcome_mean G⁰ = model.initial_factors.propensity_score - H, w = clever_covariate_and_weights( + H, w, signs = clever_covariate_and_weights( model.Ψ, G⁰, X; ps_lowerbound=model.ps_lowerbound, weighted_fluctuation=model.weighted ) - ŷ = MLJBase.predict(Q⁰, X) - return Dict{Symbol, Any}(:H => H, :w => w, :ŷ => ŷ, :y => float(y)) + ŷ = MLJBase.predict(Q⁰, X) + return Dict{Symbol, Any}(:H => H, :w => w, :signs => signs, :ŷ => ŷ, :y => float(y)) end """ @@ -84,7 +96,7 @@ function initialize_counterfactual_cache(model, X) T_ct = counterfactualTreatment(vals, Ttemplate) X_ct = DataFrame((;(Symbol(colname) => colname ∈ names(T_ct) ? T_ct[!, colname] : X[!, colname] for colname in names(X))...)) - covariates_ct, w_ct = clever_covariate_and_weights(Ψ, + covariates_ct, _, _ = clever_covariate_and_weights(Ψ, G⁰, X_ct; ps_lowerbound=model.ps_lowerbound, @@ -137,7 +149,8 @@ function compute_gradient_and_estimate_from_caches!( # Compute gradient Ey = expected_value(observed_cache[:ŷ]) ct_aggregate = compute_counterfactual_aggregate!(counterfactual_cache, Q) - gradient_Y_X = ∇YX(observed_cache[:H], observed_cache[:y], Ey, observed_cache[:w]) + H_combined = observed_cache[:H] * observed_cache[:signs] + gradient_Y_X = ∇YX(H_combined, observed_cache[:y], Ey, observed_cache[:w]) gradient, Ψ̂ = gradient_and_estimate(ct_aggregate, gradient_Y_X, observed_cache[:y], prevalence_weights) return gradient, Ψ̂ end @@ -219,6 +232,7 @@ function gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) ctl_sum = q̄₀_over_J * sum(ct_aggregate_controls_batch) gradient[case_id] = q₀ * (gradient_Y_X_cases[case_id] + ct_aggregate_case) + ctl_sum end + point_estimate /= nC gradient .-= point_estimate @@ -234,7 +248,10 @@ end get_fluctuation_weights(prevalence_weights::Nothing, clever_covariate_weights) = clever_covariate_weights -get_fluctuation_weights(prevalence_weights, clever_covariate_weights) = clever_covariate_weights .* prevalence_weights +function get_fluctuation_weights(prevalence_weights, clever_covariate_weights) + normalised_prevalence_weights = (prevalence_weights / sum(prevalence_weights)) * length(prevalence_weights) + return clever_covariate_weights .* normalised_prevalence_weights +end """ MLJBase.fit(model::Fluctuation, verbosity, X, y) @@ -307,7 +324,7 @@ end Generates initial predictions and iteratively predicts from the fitted fluctuations. """ function MLJBase.predict(model::Fluctuation, machines, X) - covariate, _ = clever_covariate_and_weights( + covariate, _, _ = clever_covariate_and_weights( model.Ψ, model.initial_factors.propensity_score, X; ps_lowerbound=model.ps_lowerbound, weighted_fluctuation=model.weighted diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index 11f6466a..57bf487f 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -52,16 +52,17 @@ Computes the projection of the gradient on the (Y | X) space. This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ -function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) +function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=nothing) # Maybe can cache some results (H and E[Y|X]) to improve perf here - H, w = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) + H, w, signs = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) y = float(dataset[!, Q.estimand.outcome]) Ey = expected_value(Q, dataset) - return ∇YX(H, y, Ey, w) + H_combined = H * signs + return ∇YX(H_combined, y, Ey, w) end -function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) +function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=nothing) Q = factors.outcome_mean G = factors.propensity_score ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) diff --git a/src/counterfactual_mean_based/nuisance_estimators.jl b/src/counterfactual_mean_based/nuisance_estimators.jl index 319c4ba6..825a75be 100644 --- a/src/counterfactual_mean_based/nuisance_estimators.jl +++ b/src/counterfactual_mean_based/nuisance_estimators.jl @@ -320,7 +320,7 @@ end function CMBasedFoldsTMLE(Ψ, initial_factors_estimate, train_validation_indices; tol=nothing, max_iter=1, - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted=false, machine_cache=false, ) @@ -461,7 +461,7 @@ function get_targeted_estimator( initial_factors_estimate; tol=nothing, max_iter=1, - ps_lowerbound=1e-8, + ps_lowerbound=nothing, weighted=true, machine_cache=false, models=nothing, diff --git a/src/estimators.jl b/src/estimators.jl index 6c489346..f33247d4 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -48,7 +48,9 @@ function fit_mlj_model(model, X, y; parents=names(X), cache=false, weights=nothi mach = machine(model, X, y; cache=cache) else if supervised_learner_supports_weights(model) - mach = machine(model, X, y, weights; cache=cache) + # Normalise weights at point-of-use + normalised_weights = (weights / sum(weights)) * length(weights) + mach = machine(model, X, y, normalised_weights; cache=cache) else throw(ArgumentError("The model $(model) does not support weights and cannot be used with prevalence.")) end @@ -61,6 +63,7 @@ end compute_prevalence_weights(prevalence, y) Calculates weights for a case-control study to use in the fitting of nuisance functions. +Returns raw (unnormalised) weights. Normalisation happens at point-of-use in fit_mlj_model and fluctuation. - `prevalence`: The prevalence of the outcome in the population. - `y`: The outcome variable across observations, which should be binary vector.` """ diff --git a/src/utils.jl b/src/utils.jl index c54d4280..c9bbab0a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,7 +35,7 @@ This is to avoid the expensive complications of: - Tracking sample_ids """ function choose_initial_dataset(dataset, fluctuation_dataset; train_validation_indices=nothing, prevalence=nothing) - # In CV mode or prevalence mode, we get back to the no fluctuation_dataset + # In CV mode or case-control weighted mode, we get back to the no fluctuation_dataset if !isnothing(train_validation_indices) || !isnothing(prevalence) return fluctuation_dataset else @@ -89,6 +89,31 @@ function indicator_values(indicators, T) return indic end +""" + indicator_matrix(indicators, T) + +Returns a tuple `(mat, signs)` where: +- `mat` is an `n × K` matrix of unsigned indicator columns (1.0 where treatment + values match, 0.0 otherwise), one column per entry in `indicators`. +- `signs` is a length-K vector of the corresponding signs from `indicators`. +""" +function indicator_matrix(indicators, T) + n = nrows(T) + indicator_keys = sort(collect(keys(indicators)), by=string) + signs = [indicators[k] for k in indicator_keys] + K = length(indicator_keys) + mat = zeros(Float64, n, K) + for (i, row) in enumerate(Tables.namedtupleiterator(T)) + vals = values(row) + for (j, key) in enumerate(indicator_keys) + if vals == key + mat[i, j] = 1.0 + end + end + end + return mat, signs +end + expected_value(ŷ::AbstractArray{<:UnivariateFinite{<:Union{OrderedFactor{2}, Multiclass{2}}}}) = pdf.(ŷ, levels(first(ŷ))[2]) expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) expected_value(ŷ::AbstractVector{<:Real}) = ŷ diff --git a/test/counterfactual_mean_based/case_control_weighted_tmle.jl b/test/counterfactual_mean_based/case_control_weighted_tmle.jl index f64ea7d8..ea90db0e 100644 --- a/test/counterfactual_mean_based/case_control_weighted_tmle.jl +++ b/test/counterfactual_mean_based/case_control_weighted_tmle.jl @@ -104,5 +104,65 @@ end @test mean(ccw_coverage) > 0.80 end +@testset "CCW-OSE bootstrapping test" begin + Random.seed!(42) + Npop = 2_000_000 + # Simulate population + W = rand(Bernoulli(0.5), Npop) + ηA = -0.2 .+ 0.8 .* W + pA = 1 ./ (1 .+ exp.(-ηA)) + A = rand.(Bernoulli.(pA)) + α, β, γ = -3, log(2), log(1.5) + pY = pY_given_A_W(A, W; α=α, β=β, γ=γ) + Y = rand.(Bernoulli.(pY)) + pop = DataFrame(W=W, A=A, Y=Y) + q₀ = mean(pop.Y .== 1) + + # Obtain the true risk difference (ATE) + true_rd = mean(pY_given_A_W(1, pop.W) .- pY_given_A_W(0, pop.W)) + + # Define ATE estimand + Ψ = ATE( + outcome=:Y, + treatment_values=(A=(case=true, control=false),), + treatment_confounders=(A=[:W],) + ) + # Standard OSE (no prevalence correction) + ose_std = Ose() + # CCW-OSE (with true prevalence) + ose_ccw = Ose(prevalence=q₀) + + # Draw a series of biased samples of size n_sample + n_sample = 10_000 + cc_prev = 0.2 + B = 30 + ccw_ose_results = Vector{Any}(undef, B) + std_ose_results = Vector{Any}(undef, B) + ccw_coverage = Vector{Bool}(undef, B) + std_coverage = Vector{Bool}(undef, B) + + for i in 1:B + sample = subsample_case_control(pop, n_sample, cc_prev, rng=Random.MersenneTwister(i)) + std_result, _ = ose_std(Ψ, sample; verbosity=0) + ccw_result, _ = ose_ccw(Ψ, sample; verbosity=0) + std_ose_results[i] = std_result.estimate + ccw_ose_results[i] = ccw_result.estimate + # Compare bias: CCW-OSE should be much less biased than standard OSE + std_bias = abs(std_result.estimate - true_rd) + ccw_bias = abs(ccw_result.estimate - true_rd) + @test ccw_bias < std_bias / 2 + # Retrieve coverage + lb, ub = confint(significance_test(ccw_result)) + ccw_coverage[i] = lb < true_rd < ub + lb, ub = confint(significance_test(std_result)) + std_coverage[i] = lb < true_rd < ub + end + # See if, on average, CCW-OSE outperforms standard OSE + @test (mean(ccw_ose_results) - true_rd) < (mean(std_ose_results) - true_rd) + # Test coverage is improved as well + @test mean(ccw_coverage) > mean(std_coverage) + @test mean(ccw_coverage) > 0.80 +end + end true \ No newline at end of file diff --git a/test/counterfactual_mean_based/clever_covariate.jl b/test/counterfactual_mean_based/clever_covariate.jl index ebb399f3..df7ca277 100644 --- a/test/counterfactual_mean_based/clever_covariate.jl +++ b/test/counterfactual_mean_based/clever_covariate.jl @@ -12,7 +12,7 @@ using DataFrames # Nothing results in data adaptive lower bound no lower than max_lb @test TMLE.ps_lower_bound(n, nothing) == max_lb @test TMLE.data_adaptive_ps_lower_bound(n) == max_lb - @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 + @test TMLE.data_adaptive_ps_lower_bound(1000) ≈ 5 / (√1000 * log(1000)) # Otherwise use the provided threhsold provided it's lower than max_lb @test TMLE.ps_lower_bound(n, 1e-8) == 1.0e-8 @test TMLE.ps_lower_bound(n, 1) == 0.1 @@ -38,20 +38,21 @@ end ) weighted_fluctuation = true ps_lowerbound = 1e-8 - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] + @test size(H) == (7, 2) + @test H * signs == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] weighted_fluctuation = false - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + @test H * signs == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] @test w == ones(7) end @@ -84,11 +85,12 @@ end ps_lowerbound = 1e-8 weighted_fluctuation = false - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 + @test size(H) == (7, 4) + @test H * signs ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] end @@ -129,14 +131,57 @@ end verbosity=0 ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 + @test size(H) == (7, 8) + @test H * signs ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] end +@testset "Test clever_covariate_and_weights: Data-Adapative 1 treatment" begin + Ψ = ATE( + outcome=:Y, + treatment_values=(T=(case="a", control="b"),), + treatment_confounders=(T=[:W],), + ) + n = 150 + dataset = DataFrame( + T = categorical(vcat(fill("a", 149), ["b"])), + Y = collect(1.0:n), + W = rand(n), + ) + + propensity_score_estimator = TMLE.JointConditionalDistributionEstimator(Dict(:T => TMLE.MLConditionalDistributionEstimator(ConstantClassifier()))) + propensity_score_estimate = propensity_score_estimator( + (TMLE.ConditionalDistribution(:T, [:W]),), + dataset, + verbosity=0 + ) + + # Se + weighted_fluctuation = true + ps_lowerbound = nothing + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + + @test w[1:149] ≈ fill(1/(149/150), 149) + @test w[150] ≈ 1/(5/(sqrt(n)*log(n))) + + weighted_fluctuation = false + H, w, signs = TMLE.clever_covariate_and_weights(Ψ, propensity_score_estimate, dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + combined = H * signs + @test combined[1:149] ≈ fill(1/(149/150), 149) + @test combined[150] ≈ -1/(5/(sqrt(n)*log(n))) + @test w == ones(150) +end + end true \ No newline at end of file diff --git a/test/counterfactual_mean_based/covariate_based_strategies.jl b/test/counterfactual_mean_based/covariate_based_strategies.jl index 5885adac..7fbc06f6 100644 --- a/test/counterfactual_mean_based/covariate_based_strategies.jl +++ b/test/counterfactual_mean_based/covariate_based_strategies.jl @@ -65,7 +65,7 @@ include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) @test adaptive_strategy.remaining_confounders == Set{Symbol}() @test adaptive_strategy.current_confounders == Set([:W₁, :W₂, :W₃]) - # Full run: this leads to only W₃ being used for the propensity score + # Full run: this leads to W₁ and W₃ being used for the propensity score ctmle = Tmle(collaborative_strategy=adaptive_strategy) Ψ = AIE( outcome = :Y, @@ -77,8 +77,8 @@ include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) ) result_ctmle, cache = ctmle(Ψ, dataset;verbosity=0); targeted_η̂ = cache[:targeted_factors] - @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₃)) - @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₃,)) + @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁, :W₃)) + @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₁, :W₃)) end @testset "Test GreedyStrategy Interface" begin @@ -158,19 +158,19 @@ end cache=cache, machine_cache=machine_cache, ) - @test new_g == (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁)), TMLE.ConditionalDistribution(:T₂, (:W₁,))) - @test new_ĝ == ĝ + @test new_g == (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₂)), TMLE.ConditionalDistribution(:T₂, ())) + @test new_ĝ == ĝ # Update the collaborative strategy - TMLE.update!(collaborative_strategy, new_g, new_ĝ) - @test collaborative_strategy.remaining_confounders == Set{Symbol}([:W₂, :W₃]) - @test collaborative_strategy.current_confounders == Set{Symbol}([:W₁]) + TMLE.update!(collaborative_strategy, new_g, new_ĝ) + @test collaborative_strategy.remaining_confounders == Set{Symbol}([:W₁, :W₃]) + @test collaborative_strategy.current_confounders == Set{Symbol}([:W₂]) # Let's iterate again step_k_candidate_iterator = TMLE.StepKPropensityScoreIterator(collaborative_strategy, Ψ, dataset, models, new_targeted_η̂ₙ) g_ĝ_candidates = collect(step_k_candidate_iterator) g_candidates = Set(first.(g_ĝ_candidates)) @test g_candidates == Set([ (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁, :W₂)), TMLE.ConditionalDistribution(:T₂, (:W₁,))), - (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁)), TMLE.ConditionalDistribution(:T₂, (:W₁, :W₃))) + (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₂)), TMLE.ConditionalDistribution(:T₂, (:W₃,))) ]) # Find optimal candidate again new_g, new_ĝ, new_targeted_η̂ₙ, new_loss, use_fluct = TMLE.step_k_best_candidate( @@ -187,7 +187,7 @@ end ) @test new_g == (TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁, :W₂)), TMLE.ConditionalDistribution(:T₂, (:W₁,))) - # Full run: this leads to only W₃ being used for the propensity score + # Full run: this leads to only W₁ being used for the propensity score ctmle = Tmle(collaborative_strategy=collaborative_strategy) Ψ = AIE( outcome = :Y, @@ -199,8 +199,8 @@ end ) result_ctmle, cache = ctmle(Ψ, dataset;verbosity=0); targeted_η̂ = cache[:targeted_factors] - @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₃)) - @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₃,)) + @test targeted_η̂.propensity_score.components[1].estimand == TMLE.ConditionalDistribution(:T₁, (:T₂, :W₁)) + @test targeted_η̂.propensity_score.components[2].estimand == TMLE.ConditionalDistribution(:T₂, (:W₁,)) end @testset "Integration Test using the AdaptiveCorrelationStrategy" begin @@ -367,7 +367,7 @@ end @test fitted_propensity_score === new_g ## The loss should be smaller because we fluctuate through the previous model, however in finite samples ## I suppose this is not warranted, we check they are approximately equal and the loss with Q̄n,k,* < Q̄n,k - @test loss ≈ new_loss atol=1e-5 + @test loss ≈ new_loss atol=1e-4 # We can pretend the step_k_best_candidate had used the non targeted outcome model new_targeted_η̂ₙ_bis, new_loss_bis = TMLE.get_new_targeted_candidate(targeted_η̂ₙ, new_ĝₙ, fluctuation_model, dataset; use_fluct=false, diff --git a/test/counterfactual_mean_based/double_robustness_aie.jl b/test/counterfactual_mean_based/double_robustness_aie.jl index ab515d9e..38703847 100644 --- a/test/counterfactual_mean_based/double_robustness_aie.jl +++ b/test/counterfactual_mean_based/double_robustness_aie.jl @@ -13,7 +13,8 @@ include(joinpath(TEST_DIR, "helper_fns.jl")) include(joinpath(TEST_DIR, "counterfactual_mean_based", "aie_simulations.jl")) cont_interacter = InteractionTransformer(order=2) |> LinearRegressor -cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1.) +# add some regularization +cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1e-6) @testset "Test Double Robustness AIE on binary_outcome_binary_treatment_pb" begin @@ -38,7 +39,7 @@ cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1. ) dr_estimators = double_robust_estimators(models, resampling=StratifiedCV()) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-6) test_mean_inf_curve_almost_zero(results.ose; atol=1e-9) # The initial estimate is far away naive = Plugin(models[:Y]) @@ -53,7 +54,7 @@ cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1. ) dr_estimators = double_robust_estimators(models, resampling=StratifiedCV()) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-7) test_mean_inf_curve_almost_zero(results.ose; atol=1e-9) # The initial estimate is far away naive = Plugin(models[:Y]) @@ -144,10 +145,10 @@ end test_mean_inf_curve_almost_zero(results.tmle; atol=1e-5) test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) - # The initial estimate is far away + # The initial plugin estimate is close to truth (Q is well specified) naive = Plugin(models[:Y]) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) - @test naive_result ≈ -0.02 atol=1e-2 + @test naive_result ≈ Ψ₀ atol=5e-2 end diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index 78cd4eb1..32d2e0db 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -48,7 +48,7 @@ end (:info, TMLE.fit_string(Q)) ) cache = Dict() - η̂ₙ = @test_logs fit_log... η̂(η, dataset; cache=cache, verbosity=1) + η̂ₙ = @test_logs fit_log... match_mode=:any η̂(η, dataset; cache=cache, verbosity=1) # Test both sub estimands have been fitted @test η̂ₙ.outcome_mean isa TMLE.MLConditionalDistribution @test fitted_params(η̂ₙ.outcome_mean.machine) isa NamedTuple @@ -59,7 +59,7 @@ end # Both models unchanged, η̂ₙ is fully reused new_η̂ = TMLE.CMRelevantFactorsEstimator(models=models) full_reuse_log = (:info, TMLE.reuse_string(η)) - @test_logs full_reuse_log new_η̂(η, dataset; cache=cache, verbosity=1) + @test_logs full_reuse_log match_mode=:any new_η̂(η, dataset; cache=cache, verbosity=1) # Changing one model, only the other one is refitted models[:T₁] = LogisticClassifier(fit_intercept=false) new_η̂ = TMLE.CMRelevantFactorsEstimator(models=models) @@ -68,7 +68,7 @@ end (:info, TMLE.fit_string(G[1])), (:info, TMLE.reuse_string(Q)) ) - @test_logs partial_reuse_log... new_η̂(η, dataset; cache=cache, verbosity=1) + @test_logs partial_reuse_log... match_mode=:any new_η̂(η, dataset; cache=cache, verbosity=1) # Adding a resampling strategy cv_fit_log = ( @@ -78,7 +78,7 @@ end ) train_validation_indices = MLJBase.train_test_pairs(CV(nfolds=3), 1:nrow(dataset), dataset) resampled_η̂ = TMLE.CMRelevantFactorsEstimator(models=models, train_validation_indices=train_validation_indices) - η̂ₙ = @test_logs cv_fit_log... resampled_η̂(η, dataset; cache=cache, verbosity=1) + η̂ₙ = @test_logs cv_fit_log... match_mode=:any resampled_η̂(η, dataset; cache=cache, verbosity=1) @test length(η̂ₙ.outcome_mean.machines) == 3 ps_component = only(η̂ₙ.propensity_score.components) @test length(ps_component.machines) == 3 diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index 47c33147..f5648ce5 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -43,20 +43,20 @@ using MLJGLMInterface expected_value = [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0] # Constant predictions of the mean as per Q⁰ @test mean.(counterfactual_cache.predictions[1]) == mean.(counterfactual_cache.predictions[2]) == expected_value @test counterfactual_cache.signs == [1., -1.] - @test counterfactual_cache.covariates == [ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0] - ] + # Counterfactual covariates are now n×K matrices + @test size(counterfactual_cache.covariates[1]) == (7, 2) + @test size(counterfactual_cache.covariates[2]) == (7, 2) observed_cache = TMLE.initialize_observed_cache(weighted_fluctuation, X, y) - @test observed_cache[:ŷ] isa Vector{<:Normal} - @test observed_cache[:H] == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] # This is used to fit, so weight has been removed + @test observed_cache[:ŷ] isa Vector{<:Normal} + @test size(observed_cache[:H]) == (7, 2) + @test observed_cache[:H] * observed_cache[:signs] == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test observed_cache[:w] == [1.75, 3.5, 7., 1.75, 1.75, 3.5, 1.75] # weight is separate @test observed_cache[:y] isa Vector{Float64} ## Second fit the fluctuation w_machines, cache, w_report = MLJBase.fit(weighted_fluctuation, 0, X, y) ### Only one machine, only fitted the clever covariate mach = only(w_machines) - @test fitted_params(mach).features == [:covariate] + @test fitted_params(mach).features == [:H_1, :H_2] ### Report entries @test length(w_report.epsilons) == length(w_report.estimates) == length(w_report.gradients) == 1 @test w_report.epsilons[1][1] !== 0 @@ -70,11 +70,11 @@ using MLJGLMInterface unweighted_fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; weighted=false, tol=0, max_iter=3) ## First check the weight and covariates from the observed cache observed_cache = TMLE.initialize_observed_cache(unweighted_fluctuation, X, y) - @test observed_cache[:H] == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + @test observed_cache[:H] * observed_cache[:signs] == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] @test observed_cache[:w] == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ## Second fit the fluctuation logs = [(:info, "TMLE step: 1."), (:info, "TMLE step: 2."), (:info, "TMLE step: 3."), (:info, "Convergence criterion not reached.")] - uw_machines, cache, uw_report = @test_logs logs... MLJBase.fit(unweighted_fluctuation, 1, X, y); + uw_machines, cache, uw_report = @test_logs logs... match_mode=:any MLJBase.fit(unweighted_fluctuation, 1, X, y); ### Only one machine, only fitted the clever covariate @test length(uw_machines) == 3 ### Report entries @@ -138,7 +138,8 @@ end @test length(counterfactual_cache.covariates) == 4 observed_cache = TMLE.initialize_observed_cache(fluctuation, X, y) @test observed_cache[:ŷ] isa UnivariateFiniteVector - @test isapprox(observed_cache[:H], [2.44, -3.26, -3.26, 2.44, 2.44, -6.12, 8.16], atol=0.1) + @test size(observed_cache[:H]) == (7, 4) + @test isapprox(observed_cache[:H] * observed_cache[:signs], [2.44, -3.26, -3.26, 2.44, 2.44, -6.12, 8.16], atol=0.1) @test observed_cache[:w] == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] @test observed_cache[:y] isa Vector{Float64} @@ -190,7 +191,7 @@ end X = dataset[!, collect(η̂ₙ.outcome_mean.estimand.parents)] y = dataset[!, η̂ₙ.outcome_mean.estimand.outcome] - fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; weighted=false, prevalence_weights=prevalence_weights, max_iter=5) + fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; weighted=false, prevalence_weights=prevalence_weights, max_iter=5, ps_lowerbound=nothing) machs, cache, report = MLJBase.fit(fluctuation, 0, X, y) gradient = report.gradients[1] @test mean(gradient) ≈ 0.0 atol=1e-4 @@ -205,17 +206,18 @@ end q_0 = 0.8 q_0_bar_over_J = 0.2 gradient, point_estimate = TMLE.gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) - # The 7ths control is not used in these computations - @test point_estimate == 0.5*( - (q_0 * 2) + q_0_bar_over_J * (1 + 3) # First case (idx=2) grouped with controls (idx=[1, 3]) - + - (q_0 * 4) + q_0_bar_over_J * (5 + 6) # Second case (idx=4) grouped with controls (idx=[5, 6]) - ) - # gradient_Y_X and ct_aggregate are summed together and the point estimate is removed - @test gradient == [ - (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1) + (3 + 3)), # First case (idx=2) grouped with controls (idx=[1, 3]) - (q_0 * (4 + 4)) + q_0_bar_over_J * ((5 + 5) + (6 + 6)) # Second case (idx=4) grouped with controls (idx=[5, 6]) - ] .- point_estimate + # The 7th control is not used in these computations (J = 5÷2 = 2 controls per case) + # Cases at indices 2, 4; Controls at indices 1, 3, 5, 6, 7 (7th dropped) + # point_estimate = sum over cases of (q₀*ct_case + q̄₀/J*sum(ct_controls)) / nC + nC = 2 + raw_point_sum = (q_0 * 2) + q_0_bar_over_J * (1 + 3) + (q_0 * 4) + q_0_bar_over_J * (5 + 6) + @test point_estimate ≈ raw_point_sum / nC atol=1e-10 + # gradient = q₀*(ct_case + grad_case) + q̄₀/J*sum(ct_ctl + grad_ctl) - point_estimate + raw_gradient = [ + (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1) + (3 + 3)), + (q_0 * (4 + 4)) + q_0_bar_over_J * ((5 + 5) + (6 + 6)) + ] + @test gradient ≈ raw_gradient .- point_estimate atol=1e-10 # When there is exactly J controls per case, all controls are used ct_aggregate = [1, 2, 3, 4] @@ -225,17 +227,17 @@ end q_0 = 0.8 q_0_bar_over_J = 0.2 gradient, point_estimate = TMLE.gradient_and_estimate(ct_aggregate, gradient_Y_X, y, weights) - # The 7ths control is not used in these computations + # Cases at indices 2, 4; Controls at indices 1, 3 (J = 2÷2 = 1 control per case) @test point_estimate ≈ 0.5*( - (q_0 * 2) + (q_0_bar_over_J * 1) # First case (idx=2) grouped with controls (idx=[1, 3]) + (q_0 * 2) + (q_0_bar_over_J * 1) + - (q_0 * 4) + (q_0_bar_over_J * 3) # Second case (idx=4) grouped with controls (idx=[5, 6]) + (q_0 * 4) + (q_0_bar_over_J * 3) ) atol=1e-10 - # gradient_Y_X and ct_aggregate are summed together and the point estimate is removed - @test gradient == [ - (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1)), # First case (idx=2) grouped with controls (idx=[1, 3]) - (q_0 * (4 + 4)) + q_0_bar_over_J * ((3 + 3)) # Second case (idx=4) grouped with controls (idx=[5, 6]) - ] .- point_estimate + # gradient_Y_X and ct_aggregate are summed together and the point estimate is removed + @test gradient ≈ [ + (q_0 * (2 + 2)) + q_0_bar_over_J * ((1 + 1)), + (q_0 * (4 + 4)) + q_0_bar_over_J * ((3 + 3)) + ] .- point_estimate atol=1e-10 end end diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 4813ac10..f182b500 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -11,10 +11,10 @@ using JSON using YAML function regression_tests(tmle_result) - @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 + @test estimate(tmle_result) ≈ -0.184910 atol = 1e-6 l, u = confint(significance_test(tmle_result)) - @test l ≈ -0.279246 atol = 1e-6 - @test u ≈ -0.091821 atol = 1e-6 + @test l ≈ -0.278604 atol = 1e-6 + @test u ≈ -0.091215 atol = 1e-6 @test OneSampleZTest(tmle_result) isa OneSampleZTest end diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 9fbc8fe9..238b181b 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -26,7 +26,7 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) cache = Dict() # Model that supports weights estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(),prevalence_weights=weights) - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) # Model that does NOT support weights (e.g., LogisticClassifier) estimator = TMLE.MLConditionalDistributionEstimator(LogisticClassifier(), prevalence_weights=weights) @@ -34,7 +34,7 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) # Pipeline that supports weights estimator = TMLE.MLConditionalDistributionEstimator(with_encoder(LinearBinaryClassifier()), prevalence_weights=weights) - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) # Pipeline that does NOT support weights estimator = TMLE.MLConditionalDistributionEstimator(with_encoder(LogisticClassifier()), prevalence_weights=weights) @@ -46,7 +46,7 @@ end estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) # Fitting with no cache cache = Dict() - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) expected_features = collect(estimand.parents) @test conditional_density_estimate isa TMLE.MLConditionalDistribution @test fitted_params(conditional_density_estimate.machine).features == expected_features @@ -61,10 +61,10 @@ end # Check cache management ## Uses the cache instead of fitting new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) - @test_logs (:info, reuse_log) estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, reuse_log) match_mode=:any estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) ## Changing the model leads to refit new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(fit_intercept=false)) - @test_logs (:info, fit_log) new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) end @testset "Test MLConditionalDistributionEstimator: continuous outcome" begin @@ -72,7 +72,7 @@ end ## Probabilistic Model model = MLJGLMInterface.LinearRegressor() estimator = TMLE.MLConditionalDistributionEstimator(model) - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, continuous_dataset; cache=Dict(), verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, continuous_dataset; cache=Dict(), verbosity=verbosity) ŷ = MLJBase.predict(conditional_density_estimate, continuous_dataset) @test ŷ isa Vector{Normal{Float64}} μ̂ = TMLE.expected_value(conditional_density_estimate, continuous_dataset) @@ -93,33 +93,32 @@ end end @testset "Test MLConditionalDistributionEstimator: binary outcome with prevalence weights" begin - # Simulate a binary outcome with imbalanced classes - n = 100 - X, y = make_moons(n) - y = copy(y) + # Binary outcome with imbalanced classes + y = copy(binary_dataset[!, :Y]) y[1:80] .= 0 # Make one class more prevalent y[81:100] .= 1 - binary_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) + case_control_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) # Set prevalence to 0.5 (true prevalence in population) prevalence = 0.5 - weights = TMLE.compute_prevalence_weights(prevalence, binary_dataset.Y) - @test Set(weights) == Set([0.5, 0.125]) # Check weights are correct + weights = TMLE.compute_prevalence_weights(prevalence, case_control_dataset.Y) + @test Set(weights) == Set([0.5, 0.125]) # Weights should not be normalised + estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(), nothing, weights) # Fit with prevalence weights cache = Dict() - conditional_density_estimate = estimator(estimand, binary_dataset; cache=cache, verbosity=1) + conditional_density_estimate = estimator(estimand, case_control_dataset; cache=cache, verbosity=1) @test conditional_density_estimate isa TMLE.MLConditionalDistribution # Check that predictions are probabilities - ŷ = MLJBase.predict(conditional_density_estimate, binary_dataset) + ŷ = MLJBase.predict(conditional_density_estimate, case_control_dataset) @test all(0.0 .<= [ŷ[i].prob_given_ref[2] for i in eachindex(ŷ)] .<= 1.0) # Check that weights are used (by comparing with unweighted fit) estimator_unweighted = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) - conditional_density_estimate_unweighted = estimator_unweighted(estimand, binary_dataset; cache=Dict(), verbosity=0) - ŷ_unweighted = MLJBase.predict(conditional_density_estimate_unweighted, binary_dataset) + conditional_density_estimate_unweighted = estimator_unweighted(estimand, case_control_dataset; cache=Dict(), verbosity=0) + ŷ_unweighted = MLJBase.predict(conditional_density_estimate_unweighted, case_control_dataset) μ̂_weighted = [ŷ[i].prob_given_ref[2] for i in eachindex(ŷ)] μ̂_unweighted = [ŷ_unweighted[i].prob_given_ref[2] for i in eachindex(ŷ_unweighted)] @test !all(isapprox.(μ̂_weighted, μ̂_unweighted; atol=1e-6)) # Should differ @@ -135,7 +134,7 @@ end train_validation_indices ) cache = Dict() - conditional_density_estimate = @test_logs (:info, fit_log) estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) + conditional_density_estimate = @test_logs (:info, fit_log) match_mode=:any estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) @test conditional_density_estimate isa TMLE.SampleSplitMLConditionalDistribution expected_features = collect(estimand.parents) @test all(fitted_params(mach).features == expected_features for mach in conditional_density_estimate.machines) @@ -159,21 +158,21 @@ end LinearBinaryClassifier(), train_validation_indices ) - @test_logs (:info, reuse_log) estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) + @test_logs (:info, reuse_log) match_mode=:any estimator(estimand, binary_dataset;cache=cache, verbosity=verbosity) ## Changing the model leads to refit new_model = LinearBinaryClassifier(fit_intercept=false) new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( new_model, train_validation_indices ) - @test_logs (:info, fit_log) new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) ## Changing the train/validation splits leads to refit train_validation_indices = Tuple(MLJBase.train_test_pairs(CV(nfolds=4), 1:n, binary_dataset)) new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( new_model, train_validation_indices ) - @test_logs (:info, fit_log) new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) + @test_logs (:info, fit_log) match_mode=:any new_estimator(estimand, binary_dataset; cache=cache, verbosity=verbosity) end @testset "Test SampleSplitMLConditionalDistributionEstimator: Continuous outcome" begin @@ -219,7 +218,7 @@ end prevalence = 0.5 binary_dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) weights = TMLE.compute_prevalence_weights(prevalence, binary_dataset.Y) - @test Set(weights) == Set([0.5, 0.125]) # Check weights are correct + @test Set(weights) == Set([0.5, 0.125]) nfolds = 3 train_validation_indices = Tuple(MLJBase.train_test_pairs(StratifiedCV(nfolds=nfolds), 1:n, binary_dataset, binary_dataset.Y)) estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) diff --git a/test/utils.jl b/test/utils.jl index b062eab3..96b00f84 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -200,7 +200,7 @@ end treatment_confounders=[:W] ) # The treatment levels correctly appear in the dataset - dataset = DataFrame(Y=rand(10), T=rand(0:1, 10), W=rand(10)) + dataset = DataFrame(Y=rand(10), T=repeat([0, 1], 5), W=rand(10)) @test TMLE.check_inputs(Ψ, dataset, nothing) isa Any # The treatment levels do not appear in the dataset dataset = DataFrame(Y=rand(10), T=rand(2:3, 10), W=rand(10))