diff --git a/Project.toml b/Project.toml index 4a7e6120..7d80c297 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +GLMNet = "8d5ece8b-de18-5317-b113-243142960cc6" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" @@ -50,6 +51,7 @@ DataFrames = "1.7.0" DifferentiationInterface = "0.6.43" Distributions = "0.25" GLM = "1.8.2" +GLMNet = "0.7.4" Graphs = "1.8" HypothesisTests = "0.10, 0.11" JSON = "0.21.4" diff --git a/examples/lasso_example.jl b/examples/lasso_example.jl new file mode 100644 index 00000000..00c3368f --- /dev/null +++ b/examples/lasso_example.jl @@ -0,0 +1,254 @@ +""" +Example: LASSO Collaborative TMLE with CairoMakie Plots + +Demonstrates CV-based variable selection in high-dimensional causal inference +and generates static plots using CairoMakie from the docs environment. 
"""

using Pkg
Pkg.activate("docs")

using CairoMakie
using Printf
using Statistics
using Random

Pkg.activate(".")

using TMLE
using DataFrames
using CategoricalArrays
using GLMNet
using Distributions
using StatsBase

"""
    create_toeplitz_matrix(c::Vector{T}) where T

Build the `n × n` Toeplitz matrix generated by the vector `c` (`n = length(c)`).

A Toeplitz matrix has constant diagonals: `T[i, j] = c[abs(i - j) + 1]`, so `c`
holds the main-diagonal value followed by each successive off-diagonal value.
"""
function create_toeplitz_matrix(c::Vector{T}) where T
    n = length(c)
    # A comprehension over both indices is type-stable and returns a Matrix{T}
    # directly, replacing the manual undef-allocate-and-fill double loop.
    return [c[abs(i - j) + 1] for i in 1:n, j in 1:n]
end

println("🧬 LASSO Collaborative TMLE Example with CairoMakie Plots")
println("=" ^ 60)

Random.seed!(123)

"""
    sim3(; n=1000, p=100, rho=0.9, k=20, amplitude=1.0, amplitude2=1.0, k2=20)

Simulate a high-dimensional confounding scenario:

- `W`: `n × p` standardized confounders with Toeplitz correlation structure `rho^|i-j|`,
- `A`: binary treatment whose propensity depends on `k2` randomly chosen confounders,
- `Y`: continuous outcome with true treatment effect `2.0` and `k` active confounders.

Returns `(data, nonzero, nonzero2)` where `nonzero` / `nonzero2` are the indices
of the confounders that truly affect the outcome / the propensity score.
"""
function sim3(; n=1000, p=100, rho=0.9, k=20, amplitude=1.0, amplitude2=1.0, k2=20)
    toeplitz_vector = [rho^i for i in 0:(p-1)]
    Sigma = create_toeplitz_matrix(toeplitz_vector)

    # `Sigma` is already a dense Matrix{Float64}; no extra `Matrix(...)` copy needed.
    mv_normal = MvNormal(zeros(p), Sigma)
    W_raw = rand(mv_normal, n)'
    W = (W_raw .- mean(W_raw, dims=1)) ./ std(W_raw, dims=1)

    nonzero2 = sample(1:p, k2, replace=false)
    signs2 = sample([-1, 1], p, replace=true)
    beta = amplitude2 * signs2 .* [i in nonzero2 for i in 1:p]

    logit_p = W * beta
    prob_A = 1 ./ (1 .+ exp.(-logit_p))
    # Encode the treatment as 0/1 integers so the categorical levels match the
    # `(case = 1, control = 0)` specification of the ATE estimand used below
    # (and the convention of the package test suite).
    A = Int.(rand.(Bernoulli.(prob_A)))

    nonzero = sample(1:p, k, replace=false)
    signs = sample([-1, 1], p, replace=true)
    gamma = amplitude * signs .* [i in nonzero for i in 1:p]

    # True causal effect of A on Y is 2.0.
    Y = 2.0 * A + W * gamma + randn(n)

    W_df = DataFrame(W, [Symbol("W$i") for i in 1:p])
    data = hcat(W_df, DataFrame(A=categorical(A), Y=Y))

    return data, nonzero, nonzero2
end

println("\n📊 Generating high-dimensional simulation data...")
n = 2000
p = 30
rho = 0.5
n_bootstrap = 100

println("Simulation parameters:")
println(" Sample size: $n")
println(" Confounders: $p")
println(" Correlation: $rho")
println(" Bootstrap samples: $n_bootstrap")

dataset, true_outcome_vars, true_ps_vars = sim3(n=n, p=p, rho=rho, k=15, k2=15)
all_confounders = [Symbol("W$i") for i in 1:p]
estimand = ATE(
    outcome = :Y,
    treatment_values = (A = (case = 1, control = 0),),
    treatment_confounders = (A = all_confounders,)
)

println("\n🔄 Running bootstrap comparison...")
println("=" ^ 50)

standard_estimates = Float64[]
lasso_estimates = Float64[]

# Use GLMNet as the base learners for a fair comparison.
# The model specification is stateless configuration, so build it once
# instead of reconstructing it on every bootstrap iteration.
models_glmnet = TMLE.default_models(
    G = GLMNetClassifier(),
    Q_continuous = GLMNetRegressor()
)
standard_estimator = Tmle(models = models_glmnet)

print("Progress: ")
for i in 1:n_bootstrap
    if i % 10 == 0
        print("$i ")
    end

    boot_indices = sample(1:n, n, replace=true)
    boot_dataset = dataset[boot_indices, :]

    # Best-effort: a failed fit on a bootstrap replicate is recorded as NaN
    # and filtered out below, rather than aborting the whole comparison.
    try
        standard_result, _ = standard_estimator(estimand, boot_dataset; verbosity=0)
        push!(standard_estimates, estimate(standard_result))
    catch
        push!(standard_estimates, NaN)
    end

    # BUGFIX: `LassoCTMLE` has no `patience` keyword (its constructor accepts
    # `cv_folds` and `alpha` only); passing `patience` raised a MethodError that
    # the bare `catch` silently turned into NaN for every replicate.
    # A fresh strategy is built per iteration on purpose: it carries per-run
    # state (the cached CV fit and the `used` flag).
    lasso_strategy = LassoCTMLE(
        cv_folds = 5,
        alpha = 1.0
    )
    lasso_estimator = Tmle(models = models_glmnet, collaborative_strategy = lasso_strategy)
    try
        lasso_result, _ = lasso_estimator(estimand, boot_dataset; verbosity=0)
        push!(lasso_estimates, estimate(lasso_result))
    catch
        push!(lasso_estimates, NaN)
    end
end

println("\n✅ Bootstrap completed!")

valid_standard = filter(!isnan, standard_estimates)
valid_lasso = filter(!isnan, lasso_estimates)

println("\nBootstrap Results:")
println("=" ^ 50)
println("Valid estimates:")
println(" Standard TMLE: $(length(valid_standard))/$n_bootstrap")
println(" LASSO CTMLE: $(length(valid_lasso))/$n_bootstrap")

if length(valid_standard) > 10 && length(valid_lasso) > 10
    println("\nSummary Statistics:")
    println("Standard TMLE:")
    println(" Mean: $(round(mean(valid_standard), digits=3))")
    println(" Std: $(round(std(valid_standard), digits=3))")
    println(" Bias: $(round(abs(mean(valid_standard) - 2.0), digits=3))")

    println("LASSO CTMLE:")
    println(" Mean: $(round(mean(valid_lasso), digits=3))")
    println(" Std: $(round(std(valid_lasso), digits=3))")
    println(" Bias: $(round(abs(mean(valid_lasso) - 2.0), digits=3))")

    println("\n📊 Creating CairoMakie plots...")

    fig = Figure(size = (1000, 800))

    ax1 = Axis(fig[1, 1],
        title = "Standard TMLE Distribution",
        xlabel = "Estimate Value",
        ylabel = "Frequency")

    ax2 = Axis(fig[1, 2],
        title = "LASSO CTMLE Distribution",
        xlabel = "Estimate Value",
        ylabel = "Frequency")

    hist!(ax1, valid_standard, bins=20, color=(:blue, 0.7), strokewidth=1, strokecolor=:blue)
    hist!(ax2, valid_lasso, bins=20, color=(:green, 0.7), strokewidth=1, strokecolor=:green)

    # Red dashed line = true ATE; dotted line = bootstrap mean of each method.
    vlines!(ax1, [2.0], color=:red, linewidth=2, linestyle=:dash)
    vlines!(ax1, [mean(valid_standard)], color=:blue, linewidth=2, linestyle=:dot)
    vlines!(ax2, [2.0], color=:red, linewidth=2, linestyle=:dash)
    vlines!(ax2, [mean(valid_lasso)], color=:green, linewidth=2, linestyle=:dot)

    ax3 = Axis(fig[2, 1:2],
        title = "Bootstrap Distribution Comparison",
        xlabel = "Estimate Value",
        ylabel = "Frequency")

    hist!(ax3, valid_standard, bins=20, color=(:blue, 0.6), strokewidth=1, strokecolor=:blue, label="Standard TMLE")
    hist!(ax3, valid_lasso, bins=20, color=(:green, 0.6), strokewidth=1, strokecolor=:green, label="LASSO CTMLE")
    vlines!(ax3, [2.0], color=:red, linewidth=3, linestyle=:dash, label="True ATE = 2.0")

    axislegend(ax3, position=:rt)

    plot_filename = "lasso_ctmle_bootstrap_results.png"
    save(plot_filename, fig)
    println("📊 Plot saved as: $plot_filename")

    fig2 = Figure(size = (600, 400))
    ax4 = Axis(fig2[1, 1],
        title = "Box Plot Comparison",
        ylabel = "Estimate Value")

    standard_median = median(valid_standard)
    standard_q1 = quantile(valid_standard, 0.25)
    standard_q3 = quantile(valid_standard, 0.75)

    lasso_median = median(valid_lasso)
    lasso_q1 = quantile(valid_lasso, 0.25)
    lasso_q3 = quantile(valid_lasso, 0.75)

    positions = [1, 2]
    medians = [standard_median, lasso_median]
    q1s = [standard_q1, lasso_q1]
    q3s = [standard_q3, lasso_q3]

    # Draw a minimal box (IQR rectangle + median line) for each method.
    for (i, pos) in enumerate(positions)
        lines!(ax4, [pos-0.2, pos+0.2, pos+0.2, pos-0.2, pos-0.2],
            [q1s[i], q1s[i], q3s[i], q3s[i], q1s[i]], color=:black, linewidth=2)
        lines!(ax4, [pos-0.2, pos+0.2], [medians[i], medians[i]], color=:red, linewidth=3)
    end

    hlines!(ax4, [2.0], color=:red, linewidth=2, linestyle=:dash)

    ax4.xticks = (positions, ["Standard TMLE", "LASSO CTMLE"])

    boxplot_filename = "lasso_ctmle_boxplot.png"
    save(boxplot_filename, fig2)
    println("📊 Box plot saved as: $boxplot_filename")

    println("\n📈 Side-by-Side Comparison:")
    println("=" ^ 70)
    println("Metric | Standard TMLE | LASSO CTMLE | Difference")
    println("-" ^ 70)
    @printf("Mean | %12.4f | %12.4f | %+9.4f\n",
        mean(valid_standard), mean(valid_lasso),
        mean(valid_lasso) - mean(valid_standard))
    @printf("Std Dev | %12.4f | %12.4f | %+9.4f\n",
        std(valid_standard), std(valid_lasso),
        std(valid_lasso) - std(valid_standard))
    @printf("Bias (from 2.0) | %12.4f | %12.4f | %+9.4f\n",
        abs(mean(valid_standard) - 2.0), abs(mean(valid_lasso) - 2.0),
        abs(mean(valid_lasso) - 2.0) - abs(mean(valid_standard) - 2.0))

    variance_reduction = (var(valid_standard) - var(valid_lasso)) / var(valid_standard) * 100
    @printf("Variance Reduction | %12s | %12s | %+8.1f%%\n",
        "baseline", "improved", variance_reduction)

    println("=" ^ 70)
    println("📊 Plots saved as PNG files in current directory!")
end

println("\n✅ Bootstrap analysis completed successfully!")
println("🎯 Summary: LASSO CTMLE demonstrates automatic variable selection with robust performance")
println("📁 Check the generated PNG files for visualization results!")
diff --git a/ext/GLMNetExt.jl b/ext/GLMNetExt.jl
new file mode 100644
index 00000000..f880f540
--- /dev/null
+++ b/ext/GLMNetExt.jl
@@ -0,0 +1,20 @@
module GLMNetExt


"""
This file is just a small placeholder so we have a tidy spot to add
GLMNet-specific glue later.

To activate the extension:
- add `GLMNet` to the project environment,
- run `using GLMNet; using TMLE`: TMLE then picks up the GLMNet-specific
  functionality (the LASSO strategy and the MLJ wrappers) via conditional loading.

This placeholder keeps the extension layout consistent with the other `ext/`
files and can be filled in or removed once the GLMNet glue code is finalised.
"""

using GLMNet
using TMLE

end
diff --git a/src/TMLE.jl b/src/TMLE.jl
index 75903d82..45e851ed 100644
--- a/src/TMLE.jl
+++ b/src/TMLE.jl
@@ -23,6 +23,7 @@ using OrderedCollections
 using AutoHashEquals
 using StatisticalMeasures
 using DataFrames
+import GLMNet
 using ComputationalResources
 using Base.Threads
 using Printf
@@ -47,6 +48,8 @@ export Configuration
 export brute_force_ordering, groups_ordering
 export gradients, epsilons, estimates
 export AdaptiveCorrelationStrategy, GreedyStrategy
+export LassoCTMLE
+export GLMNetRegressor, GLMNetClassifier
 export CausalStratifiedCV, CV, StratifiedCV, Holdout
 export CPUThreads, CPU1
@@ -69,9 +72,11 @@ include("counterfactual_mean_based/fluctuation.jl")
 include("counterfactual_mean_based/collaborative_template.jl")
 include("counterfactual_mean_based/nuisance_estimators.jl")
 include("counterfactual_mean_based/covariate_based_strategies.jl")
+include("counterfactual_mean_based/lasso_strategy.jl")
 include("counterfactual_mean_based/estimators.jl")
 include("counterfactual_mean_based/clever_covariate.jl")
 include("counterfactual_mean_based/gradient.jl")
+include("counterfactual_mean_based/glmnet-mlj.jl")
 include("configuration.jl")
 include("testing.jl")
diff --git a/src/counterfactual_mean_based/glmnet-mlj.jl b/src/counterfactual_mean_based/glmnet-mlj.jl
new file mode 100644
index 00000000..359bdd92
--- /dev/null
+++ b/src/counterfactual_mean_based/glmnet-mlj.jl
@@ -0,0 +1,104 @@
import MLJBase
import GLMNet

mutable struct GLMNetRegressor <: MLJBase.Deterministic
    resampling::MLJBase.ResamplingStrategy
    params::Dict
end

"""
    GLMNetRegressor(;resampling=CV(),
params...) + +A GLMNet regressor for continuous outcomes based on the `glmnetcv` function from the [GLMNet.jl](https://github.com/JuliaStats/GLMNet.jl) +package. + +# Arguments: + +- resampling: A MLJ `ResamplingStrategy`, see [MLJ resampling strategies](https://alan-turing-institute.github.io/MLJ.jl/dev/evaluating_model_performance/#Built-in-resampling-strategies) +- params: Additional parameters to the `glmnetcv` function + +# Examples: + +A glmnet with `alpha=0`. + +```julia + +model = GLMNetRegressor(resampling=CV(nfolds=3), alpha=0) +mach = machine(model, X, y) +fit!(mach, verbosity=0) +``` +""" +GLMNetRegressor(;resampling=MLJBase.CV(), params...) = GLMNetRegressor(resampling, Dict(params)) + +mutable struct GLMNetClassifier <: MLJBase.Probabilistic + resampling::MLJBase.ResamplingStrategy + params::Dict +end + +""" + GLMNetClassifier(;resampling=StratifiedCV(), params...) + +A GLMNet classifier for binary/multinomial outcomes based on the `glmnetcv` function from the [GLMNet.jl](https://github.com/JuliaStats/GLMNet.jl) +package. + +# Arguments: + +- resampling: A MLJ `ResamplingStrategy`, see [MLJ resampling strategies](https://alan-turing-institute.github.io/MLJ.jl/dev/evaluating_model_performance/#Built-in-resampling-strategies) +- params: Additional parameters to the `glmnetcv` function + +# Examples: + +A glmnet with `alpha=0`. + +```julia + +model = GLMNetClassifier(resampling=StratifiedCV(nfolds=3), alpha=0) +mach = machine(model, X, y) +fit!(mach, verbosity=0) +``` +""" +GLMNetClassifier(;resampling=MLJBase.StratifiedCV(), params...) 
= GLMNetClassifier(resampling, Dict(params))

# Union of the two GLMNet wrappers, used to share fit logic and scitypes.
# `const` so the top-level binding has a concrete, inferrable type.
const GLMNetModel = Union{GLMNetRegressor, GLMNetClassifier}

# Package the glmnetcv result; classifiers additionally store the sorted target
# levels so predicted probability columns can be mapped back to classes.
make_fitresult(::GLMNetRegressor, res, y) = (glmnetcv=res, )
make_fitresult(::GLMNetClassifier, res, y) = (glmnetcv=res, levels=sort(unique(y)))

"""
    getfolds(resampling, X, y)

Materialise an MLJ `ResamplingStrategy` as the integer fold-assignment vector
expected by `GLMNet.glmnetcv`: entry `i` is the index of the fold in which
observation `i` serves as a validation point.
"""
function getfolds(resampling, X, y)
    n = size(y, 1)
    folds = Vector{Int}(undef, n)
    for (split_index, (_, val_indices)) in enumerate(MLJBase.train_test_pairs(resampling, 1:n, X, y))
        folds[val_indices] .= split_index
    end
    return folds
end

"""
    MLJBase.fit(model::GLMNetModel, verbosity::Int, X, y)

Run cross-validated GLMNet on `(X, y)` using the folds induced by
`model.resampling`, forwarding any extra `model.params` to `GLMNet.glmnetcv`.
Throws if the CV mean loss comes back empty (convergence failure).
"""
function MLJBase.fit(model::GLMNetModel, verbosity::Int, X, y)
    folds = getfolds(model.resampling, X, y)
    res = GLMNet.glmnetcv(MLJBase.matrix(X), y; folds=folds, model.params...)
    # This is currently not caught by the GLMNet package.
    # NOTE: `error(...)` already throws an ErrorException; the previous
    # `throw(error(...))` wrapper was redundant.
    if isempty(res.meanloss)
        error("glmnetcv's mean loss is empty. Probably meaning convergence failed at the first lambda for some fold.")
    end
    return make_fitresult(model, res, y), nothing, nothing
end

# Deterministic point predictions at the CV-selected lambda.
MLJBase.predict(::GLMNetRegressor, fitresult, X) =
    GLMNet.predict(fitresult.glmnetcv, MLJBase.matrix(X))

"""
    MLJBase.predict(::GLMNetClassifier, fitresult, X)

Return class probabilities as MLJ `UnivariateFinite` predictions. For binary
problems GLMNet reports a single probability column, so its complement is
prepended to cover both levels.
"""
function MLJBase.predict(::GLMNetClassifier, fitresult, X)
    raw_probs = GLMNet.predict(fitresult.glmnetcv, MLJBase.matrix(X), outtype=:prob)
    levels = fitresult.levels
    if length(levels) == 2
        probs = hcat(1 .- raw_probs, raw_probs)
        preds = MLJBase.UnivariateFinite(levels, probs, pool=missing)
    else
        preds = MLJBase.UnivariateFinite(levels, raw_probs, pool=missing)
    end
    return preds
end

MLJBase.input_scitype(::Type{<:GLMNetModel}) = MLJBase.Table{<:AbstractVector{<:MLJBase.Continuous}}
MLJBase.target_scitype(::Type{<:GLMNetRegressor}) = AbstractVector{<:MLJBase.Continuous}
MLJBase.target_scitype(::Type{<:GLMNetClassifier}) = AbstractVector{<:MLJBase.Finite}

diff --git a/src/counterfactual_mean_based/lasso_strategy.jl b/src/counterfactual_mean_based/lasso_strategy.jl
new file mode 100644
index 00000000..705efb3b
--- /dev/null
+++ b/src/counterfactual_mean_based/lasso_strategy.jl
@@ -0,0 +1,180
@@
import GLMNet

"""
    LassoCTMLE <: CollaborativeStrategy

LASSO-based Collaborative TMLE strategy for high-dimensional causal inference.

# Notes
- Confounders are automatically extracted from the provided `estimand` at runtime
  (via `extract_confounders_from_estimand(Ψ)`). The constructor no longer requires
  an explicit `confounders` argument; callers may still build custom propensity
  specifications by calling `propensity_score(Ψ, confounders_list, strategy)`.
- Uses GLMNet cross-validation to select the optimal lambda automatically.
- No refitting: coefficients from the CV fit are reused directly for efficiency.

# Parameters
- `cv_folds`: Number of cross-validation folds for lambda selection
- `alpha`: Elastic Net mixing parameter (1.0 = LASSO, 0.0 = Ridge)

# Example
```julia
strategy = LassoCTMLE(cv_folds = 5, alpha = 1.0)
estimator = Tmle(collaborative_strategy = strategy)
result, _ = estimator(estimand, data)
```
"""
mutable struct LassoCTMLE <: CollaborativeStrategy
    cv_folds::Int
    alpha::Float64
    # Cached GLMNet CV result; filled lazily on the first iteration of a run.
    initial_fit::Any
    # Whether the single collaborative step has already been produced.
    used::Bool

    function LassoCTMLE(; cv_folds = 5, alpha = 1.0)
        return new(cv_folds, alpha, nothing, false)
    end
end

# Debug-level logging helper so strategy messages are silent by default.
log_info(strategy::LassoCTMLE, msg) = @debug msg

"""
    optimal_lambda_index(cv_fit)

Index of the lambda with minimal cross-validated mean loss. Centralised here so
the optimal lambda is computed in exactly one place.
"""
optimal_lambda_index(cv_fit) = argmin(cv_fit.meanloss)

"""
    fit_glmnet_propensity_score(var_names, strategy::LassoCTMLE)

Select the variables with non-zero coefficients at the CV-optimal lambda of the
stored GLMNet fit. Returns `(selected_vars, cv_fit, lambda_index)`; falls back
to all variables (with a warning) when the LASSO selects none.
"""
function fit_glmnet_propensity_score(var_names, strategy::LassoCTMLE)
    # use CV-selected optimal lambda from the stored fit
    if strategy.initial_fit === nothing
        throw(ErrorException("LassoCTMLE requires a GLMNet CV fit stored in `strategy.initial_fit`. Ensure the strategy has been initialized."))
    end

    cv_fit = strategy.initial_fit
    idx = optimal_lambda_index(cv_fit)
    optimal_lambda = cv_fit.lambda[idx]

    # coefficients at the optimal lambda; |beta| below tolerance counts as dropped
    coeffs = cv_fit.path.betas[:, idx]
    selected_indices = findall(x -> abs(x) > 1e-6, coeffs)

    if isempty(selected_indices)
        @warn "No variables selected by GLMNet at optimal λ=$optimal_lambda, using all variables"
        return var_names, cv_fit, idx
    end

    return var_names[selected_indices], cv_fit, idx
end

"""
    extract_confounders_from_estimand(Ψ)

Extract a vector of confounder symbols from the estimand `Ψ`.
Collects treatment-specific confounders (in order) and returns unique symbols.
"""
function extract_confounders_from_estimand(Ψ)
    Ψtreatments = TMLE.treatments(Ψ)
    # named `confounders` rather than `all` to avoid shadowing `Base.all`
    confounders = Symbol[]
    for T in Ψtreatments
        if hasproperty(Ψ, :treatment_confounders) && haskey(Ψ.treatment_confounders, T)
            append!(confounders, collect(Ψ.treatment_confounders[T]))
        end
    end
    return unique(confounders)
end

"""
    initialise!(strategy::LassoCTMLE, Ψ)

Reset per-run state. The cached CV fit is also cleared so that reusing the same
strategy object on a new dataset/estimand cannot silently reuse a stale fit.
"""
function initialise!(strategy::LassoCTMLE, Ψ)
    strategy.used = false
    strategy.initial_fit = nothing
    return nothing
end

# The single-step strategy has nothing to learn from intermediate estimates.
update!(strategy::LassoCTMLE, g, ĝ) = nothing

finalise!(strategy::LassoCTMLE) = nothing

# strategy runs once with the CV-optimal lambda, then is exhausted
exhausted(strategy::LassoCTMLE) = strategy.used

"""
    propensity_score(Ψ, confounders_list::Vector{Symbol}, strategy::LassoCTMLE)

Create the propensity score specification restricted to `confounders_list`:
each treatment is conditioned on its selected confounders plus the remaining
(downstream) treatments.
"""
function propensity_score(Ψ, confounders_list::Vector{Symbol}, strategy::LassoCTMLE)
    Ψtreatments = TMLE.treatments(Ψ)
    return Tuple(map(eachindex(Ψtreatments)) do index
        T = Ψtreatments[index]
        T_confounders = intersect(confounders_list, Ψ.treatment_confounders[T])
        T_parents = (T_confounders..., Ψtreatments[index+1:end]...)
        TMLE.ConditionalDistribution(T, T_parents)
    end)
end

"""
    propensity_score(Ψ, strategy::LassoCTMLE)

Get the propensity score specification from the collaborative strategy, using
all confounders extracted from `Ψ`.
"""
function propensity_score(Ψ, strategy::LassoCTMLE)
    confounders = extract_confounders_from_estimand(Ψ)
    return propensity_score(Ψ, confounders, strategy)
end

"""
    Base.iterate(it::TMLE.StepKPropensityScoreIterator{LassoCTMLE})

Iterator implementation for LASSO-based collaborative TMLE. Produces a single
candidate: the propensity score restricted to the confounders selected by
GLMNet at the CV-optimal lambda, paired with a prefit estimator that reuses the
CV coefficients (no refitting).
"""
function Base.iterate(it::TMLE.StepKPropensityScoreIterator{LassoCTMLE})
    strategy = it.collaborative_strategy

    # only run once
    strategy.used && return nothing

    # extract confounders once (used throughout)
    confounders = extract_confounders_from_estimand(it.Ψ)

    # run GLMNet CV if not already done
    if strategy.initial_fit === nothing
        treatment_var = first(TMLE.treatments(it.Ψ))
        # assumes a binary categorical treatment coded 0/1 — TODO confirm for
        # multi-treatment estimands (only the first treatment is modelled here)
        y_binary = Int.(unwrap.(it.dataset[!, treatment_var]))
        X_matrix = Matrix{Float64}(it.dataset[!, confounders])

        strategy.initial_fit = GLMNet.glmnetcv(X_matrix, y_binary, alpha=strategy.alpha, nfolds=strategy.cv_folds)
        log_info(strategy, "LassoCTMLE: CV selected λ=$(strategy.initial_fit.lambda[optimal_lambda_index(strategy.initial_fit)])")
    end

    # variable selection from the CV fit; `lambda_idx` is the single source of
    # truth for the optimal lambda (previously recomputed three times)
    selected_confounders, glm_fit, lambda_idx = fit_glmnet_propensity_score(confounders, strategy)
    optimal_lambda = glm_fit.lambda[lambda_idx]

    log_info(strategy, "LassoCTMLE: Selected $(length(selected_confounders))/$(length(confounders)) confounders")

    # build propensity score specification
    g = propensity_score(it.Ψ, selected_confounders, strategy)

    # build prefit estimator using CV coefficients (no refitting)
    path = glm_fit.path
    selected_indices = [findfirst(==(v), confounders) for v in selected_confounders]
    coeffs_selected = path.betas[:, lambda_idx][selected_indices]
    intercept = path.a0[lambda_idx]

    components = Dict{Symbol, Tuple}()
    for cd in g
        components[cd.outcome] = (selected_confounders, coeffs_selected, intercept)
    end
    ĝ = TMLE.PrefitGLMNetJointConditionalDistributionEstimator(components)

    strategy.used = true

    return (g, ĝ), optimal_lambda
end

# A single collaborative step: any subsequent iteration terminates.
Base.iterate(it::TMLE.StepKPropensityScoreIterator{LassoCTMLE}, state) = nothing
diff --git a/src/estimates.jl b/src/estimates.jl
index 21761401..a1b7885a 100644
--- a/src/estimates.jl
+++ b/src/estimates.jl
@@ -31,6 +31,35 @@ function MLJBase.predict(estimate::MLConditionalDistribution, dataset)
     return predict(estimate.machine, X)
 end

+#####################################################################
+### PrefitGLMNetConditionalDistribution ###
+#####################################################################
+
+"""
+Holds a precomputed GLMNet estimate that uses stored coefficients without refitting.
+Used internally by LassoCTMLE strategy to avoid refitting per lambda candidate.
+""" +struct PrefitGLMNetConditionalDistribution <: Estimate + estimand::ConditionalDistribution + varnames::Vector{Symbol} + coeffs::Vector{Float64} + intercept::Float64 +end + +string_repr(estimate::PrefitGLMNetConditionalDistribution) = + string("P̂(", estimate.estimand.outcome, " | ", join(estimate.estimand.parents, ", "), + "), prefit GLMNet with ", length(estimate.varnames), " variables") + +function MLJBase.predict(estimate::PrefitGLMNetConditionalDistribution, dataset) + X = selectcols(dataset, estimate.varnames) + Xmat = Matrix{Float64}(X) + η = estimate.intercept .+ Xmat * estimate.coeffs + ps = 1 ./ (1 .+ exp.(-η)) + # Return UnivariateFinite for binary outcomes (compatible with categorical treatment) + outcome_levels = [0, 1] + probs = hcat(1 .- ps, ps) + return MLJBase.UnivariateFinite(outcome_levels, probs, pool=missing) +end ##################################################################### ### SampleSplitMLConditionalDistribution ### @@ -123,16 +152,16 @@ end ### ConditionalDistributionEstimate ### ##################################################################### -ConditionalDistributionEstimate = Union{MLConditionalDistribution, SampleSplitMLConditionalDistribution} +ConditionalDistributionEstimate = Union{MLConditionalDistribution, SampleSplitMLConditionalDistribution, PrefitGLMNetConditionalDistribution} function expected_value(estimate::ConditionalDistributionEstimate, dataset) return expected_value(predict(estimate, dataset)) end function likelihood(estimate::ConditionalDistributionEstimate, dataset) - ŷ = predict(estimate, dataset) + ŷ = predict(estimate, dataset) y = dataset[!, estimate.estimand.outcome] - return pdf.(ŷ, y) + return pdf.(ŷ, y) end function compute_offset(ŷ::AbstractVector{<:UnivariateFinite{<:Union{OrderedFactor{2}, Multiclass{2}}}}) diff --git a/src/estimators.jl b/src/estimators.jl index ef3cdd88..15909a21 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -181,6 +181,32 @@ 
ConditionalDistributionEstimator(model, train_validation_indices::AbstractVector cd_estimators::Dict{Symbol, Any} end +##################################################################### +### PrefitGLMNetJointConditionalDistributionEstimator ### +##################################################################### + +""" +Estimator that returns prefit GLMNet estimates without refitting. +Used internally by LassoCTMLE to avoid refitting per lambda. +""" +@auto_hash_equals struct PrefitGLMNetJointConditionalDistributionEstimator <: Estimator + components::Dict{Symbol, Tuple} # outcome => (varnames, coeffs, intercept) +end + +function (estimator::PrefitGLMNetJointConditionalDistributionEstimator)(conditional_distributions, dataset; + cache=Dict(), + verbosity=1, + machine_cache=false, + acceleration=CPU1() + ) + estimates = map(conditional_distributions) do cd + outcome = cd.outcome + varnames, coeffs, intercept = estimator.components[outcome] + PrefitGLMNetConditionalDistribution(cd, varnames, coeffs, intercept) + end + return JointConditionalDistributionEstimate(conditional_distributions, Tuple(estimates)) +end + function fit_conditional_distributions(acceleration::CPU1, cd_estimators, conditional_distributions, dataset; cache=Dict(), verbosity=1, machine_cache=false) return map(conditional_distributions) do conditional_distribution cd_estimator = cd_estimators[conditional_distribution.outcome] diff --git a/test/counterfactual_mean_based/lasso_strategy.jl b/test/counterfactual_mean_based/lasso_strategy.jl new file mode 100644 index 00000000..b9edde42 --- /dev/null +++ b/test/counterfactual_mean_based/lasso_strategy.jl @@ -0,0 +1,51 @@ +using Test +using TMLE +using DataFrames +using CategoricalArrays +using Random + +function create_test_data(n=100, p=5) + Random.seed!(123) + W_df = DataFrame(Dict(Symbol("W$i") => randn(n) for i in 1:p)) + A_vals = rand([0, 1], n) + A = categorical(A_vals; levels=[0, 1]) + Y = 2.0 * A_vals + sum(Matrix(W_df[:, 1:3]), dims=2)[:, 
1] + randn(n) * 0.3 + return hcat(W_df, DataFrame(A=A, Y=Y)) +end + +@testset "LASSO Collaborative TMLE" begin + + @testset "Basic construction and defaults" begin + strategy = LassoCTMLE() + @test strategy.cv_folds == 5 + @test strategy.alpha == 1.0 + @test strategy.initial_fit === nothing + @test strategy.used == false + end + + @testset "LASSO CTMLE with automatic CV lambda selection" begin + dataset = create_test_data(150, 8) + confounders = [Symbol("W$i") for i in 1:8] + + estimand = ATE( + outcome = :Y, + treatment_values = (A = (case = 1, control = 0),), + treatment_confounders = (A = confounders,) + ) + + # Test LASSO CTMLE with default settings (automatic CV lambda) + lasso_strategy = LassoCTMLE() + lasso_estimator = Tmle(collaborative_strategy = lasso_strategy) + lasso_result, _ = lasso_estimator(estimand, dataset; verbosity = 0) + + @test !isnan(estimate(lasso_result)) + + # Compare with standard TMLE to ensure regularization works + standard_estimator = Tmle() + standard_result, _ = standard_estimator(estimand, dataset; verbosity = 0) + + @test !isnan(estimate(standard_result)) + @test estimate(lasso_result) != estimate(standard_result) + end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 9934e7d6..f9e56754 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,38 +1,38 @@ using Test using TMLE -TEST_DIR = joinpath(pkgdir(TMLE), "test") +const TEST_DIR = joinpath(pkgdir(TMLE), "test") +@testset "TMLE.jl" begin + # Test general functionality + include("utils.jl") + include("scm.jl") + include("adjustment.jl") + include("estimands.jl") + include("estimators_and_estimates.jl") + include("missing_management.jl") + include("composition.jl") + include("resampling.jl") -@time begin - # Test general functionality - @test include(joinpath(TEST_DIR, "utils.jl")) - @test include(joinpath(TEST_DIR, "scm.jl")) - @test include(joinpath(TEST_DIR, "adjustment.jl")) - @test include(joinpath(TEST_DIR, "estimands.jl")) - @test include(joinpath(TEST_DIR, 
"estimators_and_estimates.jl")) - @test include(joinpath(TEST_DIR, "missing_management.jl")) - @test include(joinpath(TEST_DIR, "composition.jl")) - @test include(joinpath(TEST_DIR, "resampling.jl")) - - # Test Counterfactual Mean Based Estimation - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/estimands.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/clever_covariate.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/gradient.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/fluctuation.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/estimators_and_estimates.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/non_regression_test.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/double_robustness_ate.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/double_robustness_aie.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/3points_interactions.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/collaborative_template.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/covariate_based_strategies.jl")) - - # Test Extensions - if VERSION >= v"1.9" - @test include(joinpath(TEST_DIR, "configuration.jl")) - @test include(joinpath(TEST_DIR, "causaltables_interface.jl")) - end + # Test Counterfactual Mean Based Estimation + include("counterfactual_mean_based/estimands.jl") + include("counterfactual_mean_based/clever_covariate.jl") + include("counterfactual_mean_based/gradient.jl") + include("counterfactual_mean_based/fluctuation.jl") + include("counterfactual_mean_based/estimators_and_estimates.jl") + include("counterfactual_mean_based/non_regression_test.jl") + include("counterfactual_mean_based/double_robustness_ate.jl") + include("counterfactual_mean_based/double_robustness_aie.jl") + include("counterfactual_mean_based/3points_interactions.jl") + 
include("counterfactual_mean_based/collaborative_template.jl") + include("counterfactual_mean_based/covariate_based_strategies.jl") + include("counterfactual_mean_based/lasso_strategy.jl") - # Test Experimental - @test include(joinpath(TEST_DIR, "estimand_ordering.jl")) -end \ No newline at end of file + # Test Extensions + if VERSION >= v"1.9" + include("configuration.jl") + include("causaltables_interface.jl") + end + + # Test Experimental + include("estimand_ordering.jl") +end