diff --git a/.github/workflows/apo_sim.yml b/.github/workflows/apo_sim.yml index 1cbdc3f4..fe5b12d7 100644 --- a/.github/workflows/apo_sim.yml +++ b/.github/workflows/apo_sim.yml @@ -18,6 +18,7 @@ jobs: matrix: script: [ 'scripts/irm/irm_apo_coverage.py', + 'scripts/irm/irm_apo_ate_sensitivity.py', ] steps: diff --git a/doc/irm/apo.qmd b/doc/irm/apo.qmd index 245d29e3..befb83ad 100644 --- a/doc/irm/apo.qmd +++ b/doc/irm/apo.qmd @@ -246,4 +246,113 @@ make_pretty(df_ate_95, level, n_rep) level = 0.9 df_ate_9 = df[df['level'] == level][display_columns] make_pretty(df_ate_9, level, n_rep) -``` \ No newline at end of file +``` + + +## Causal Contrast Sensitivity + +The simulations are based on the the ADD-DGP with $10,000$ observations. As the DGP is nonlinear, we will only use corresponding learners. Since the DGP includes an unobserved confounder, we would expect a bias in the ATE estimates, leading to low coverage of the true parameter. + +The confounding is set such that both sensitivity parameters are approximately $cf_y=cf_d=0.1$, such that the robustness value $RV$ should be approximately $10\%$. +Further, the corresponding confidence intervals are one-sided (since the direction of the bias is unkown), such that only one side should approximate the corresponding coverage level (here only the lower coverage is relevant since the bias is positive). Remark that for the coverage level the value of $\rho$ has to be correctly specified, such that the coverage level will be generally (significantly) larger than the nominal level under the conservative choice of $|\rho|=1$. + +```{python} +#| echo: false + +import numpy as np +import pandas as pd +from itables import init_notebook_mode, show, options + +init_notebook_mode(all_interactive=True) + +def highlight_range(s, level=0.95, dist=0.05, props=''): + color_grid = np.where((s >= level-dist) & + (s <= level+dist), props, '') + return color_grid + + +def color_coverage(df, level): + # color coverage column order is important + styled_df = df.apply( + highlight_range, + level=level, + dist=1.0, + props='color:black;background-color:red', + subset=["Coverage", "Coverage (Lower)"]) + styled_df = styled_df.apply( + highlight_range, + level=level, + dist=0.1, + props='color:black;background-color:yellow', + subset=["Coverage", "Coverage (Lower)"]) + styled_df = styled_df.apply( + highlight_range, + level=level, + dist=0.05, + props='color:white;background-color:darkgreen', + subset=["Coverage", "Coverage (Lower)"]) + + # set all coverage values to bold + styled_df = styled_df.set_properties( + **{'font-weight': 'bold'}, + subset=["Coverage", "Coverage (Lower)"]) + return styled_df + + +def make_pretty(df, level, n_rep): + styled_df = df.style.hide(axis="index") + # Format only float columns + float_cols = df.select_dtypes(include=['float']).columns + styled_df = styled_df.format({col: "{:.3f}" for col in float_cols}) + + # color coverage column order is important + styled_df = color_coverage(styled_df, level) + caption = f"Coverage for {level*100}%-Confidence Interval over {n_rep} Repetitions" + + return show(styled_df, caption=caption) +``` + +### ATE + +::: {.callout-note title="Metadata" collapse="true"} + +```{python} +#| echo: false +metadata_file = '../../results/irm/irm_apo_sensitivity_metadata.csv' +metadata_df = pd.read_csv(metadata_file) +print(metadata_df.T.to_string(header=False)) +``` + +::: + +```{python} +#| echo: false + +# set up data and rename columns +df = pd.read_csv("../../results/irm/irm_apo_sensitivity.csv", index_col=None) + +assert df["repetition"].nunique() == 1 +n_rep = df["repetition"].unique()[0] + +display_columns = [ + "Learner g", "Learner m", "Bias", "Bias (Lower)", "Bias (Upper)", "Coverage", "Coverage (Lower)", "Coverage (Upper)", "RV", "RVa"] +``` + +```{python} +#| echo: false +level = 0.95 + +df_ate_95 = df[(df['level'] == level)][display_columns] +df_ate_95.rename(columns={"Learner g": "Learner l"}, inplace=True) +make_pretty(df_ate_95, level, n_rep) +``` + +```{python} +#| echo: false +level = 0.9 + +df_ate_9 = df[(df['level'] == level)][display_columns] +df_ate_9.rename(columns={"Learner g": "Learner l"}, inplace=True) +make_pretty(df_ate_9, level, n_rep) +``` + diff --git a/doc/irm/irm.qmd b/doc/irm/irm.qmd index 0425f326..0dd12991 100644 --- a/doc/irm/irm.qmd +++ b/doc/irm/irm.qmd @@ -242,7 +242,6 @@ display_columns = [ ```{python} #| echo: false -score = "partialling out" level = 0.95 df_ate_95 = df[(df['level'] == level)][display_columns] @@ -252,7 +251,6 @@ make_pretty(df_ate_95, level, n_rep) ```{python} #| echo: false -score = "partialling out" level = 0.9 df_ate_9 = df[(df['level'] == level)][display_columns] diff --git a/requirements.txt b/requirements.txt index 2accaaba..e4b853e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ doubleml[rdd] joblib numpy pandas -scikit-learn +scikit-learn==1.5.2 lightgbm itables ipykernel \ No newline at end of file diff --git a/results/irm/irm_apo_coverage_apo.csv b/results/irm/irm_apo_coverage_apo.csv index c4e88516..8ee65bbc 100644 --- a/results/irm/irm_apo_coverage_apo.csv +++ b/results/irm/irm_apo_coverage_apo.csv @@ -1,25 +1,25 @@ Learner g,Learner m,Treatment Level,level,Coverage,CI Length,Bias,repetition -LGBM,LGBM,0.0,0.9,0.911,8.657690136121921,2.0760918678352267,1000 -LGBM,LGBM,0.0,0.95,0.963,10.316274091547035,2.0760918678352267,1000 -LGBM,LGBM,1.0,0.9,0.914,38.23339285821166,9.19332007236871,1000 -LGBM,LGBM,1.0,0.95,0.967,45.55789754237906,9.19332007236871,1000 -LGBM,LGBM,2.0,0.9,0.891,37.49194764946096,9.594147407062762,1000 -LGBM,LGBM,2.0,0.95,0.952,44.67441108385784,9.594147407062762,1000 -LGBM,Logistic,0.0,0.9,0.901,5.625897101886533,1.3366474407456235,1000 -LGBM,Logistic,0.0,0.95,0.955,6.7036698705295725,1.3366474407456235,1000 -LGBM,Logistic,1.0,0.9,0.92,7.423300143143785,1.6954254817667072,1000 -LGBM,Logistic,1.0,0.95,0.968,8.84540769378873,1.6954254817667072,1000 -LGBM,Logistic,2.0,0.9,0.918,7.321275660150268,1.660607504960853,1000 -LGBM,Logistic,2.0,0.95,0.971,8.723838024042964,1.660607504960853,1000 -Linear,LGBM,0.0,0.9,0.901,5.498257423071024,1.307400569952483,1000 -Linear,LGBM,0.0,0.95,0.949,6.551577812380716,1.307400569952483,1000 -Linear,LGBM,1.0,0.9,0.951,10.700720020780512,2.1270559074131152,1000 -Linear,LGBM,1.0,0.95,0.982,12.75069435099058,2.1270559074131152,1000 -Linear,LGBM,2.0,0.9,0.932,7.513644049429104,1.634643965398859,1000 -Linear,LGBM,2.0,0.95,0.967,8.953059097926168,1.634643965398859,1000 -Linear,Logistic,0.0,0.9,0.9,5.335670092717667,1.2859337252007816,1000 -Linear,Logistic,0.0,0.95,0.951,6.357843058957276,1.2859337252007816,1000 -Linear,Logistic,1.0,0.9,0.904,5.417512107920403,1.2802845007629777,1000 -Linear,Logistic,1.0,0.95,0.954,6.455363835025866,1.2802845007629777,1000 -Linear,Logistic,2.0,0.9,0.905,5.366403391397197,1.2808554104576877,1000 -Linear,Logistic,2.0,0.95,0.959,6.3944640430685675,1.2808554104576877,1000 +LGBM,LGBM,0.0,0.9,0.91,8.657690136121921,2.077166476271961,1000 +LGBM,LGBM,0.0,0.95,0.966,10.316274091547035,2.077166476271961,1000 +LGBM,LGBM,1.0,0.9,0.914,38.23339285821166,9.216990722545221,1000 +LGBM,LGBM,1.0,0.95,0.968,45.55789754237906,9.216990722545221,1000 +LGBM,LGBM,2.0,0.9,0.894,37.49194764946096,9.572829064615373,1000 +LGBM,LGBM,2.0,0.95,0.953,44.67441108385784,9.572829064615373,1000 +LGBM,Logistic,0.0,0.9,0.904,5.625897101886533,1.341832646808728,1000 +LGBM,Logistic,0.0,0.95,0.955,6.7036698705295725,1.341832646808728,1000 +LGBM,Logistic,1.0,0.9,0.924,7.423300143143785,1.6934000724558478,1000 +LGBM,Logistic,1.0,0.95,0.969,8.84540769378873,1.6934000724558478,1000 +LGBM,Logistic,2.0,0.9,0.92,7.321275660150268,1.6623515097372739,1000 +LGBM,Logistic,2.0,0.95,0.965,8.723838024042964,1.6623515097372739,1000 +Linear,LGBM,0.0,0.9,0.9,5.498257423071024,1.3124516734608647,1000 +Linear,LGBM,0.0,0.95,0.95,6.551577812380716,1.3124516734608647,1000 +Linear,LGBM,1.0,0.9,0.948,10.700720020780512,2.1292560240454224,1000 +Linear,LGBM,1.0,0.95,0.983,12.75069435099058,2.1292560240454224,1000 +Linear,LGBM,2.0,0.9,0.934,7.513644049429104,1.6375362733403969,1000 +Linear,LGBM,2.0,0.95,0.966,8.953059097926168,1.6375362733403969,1000 +Linear,Logistic,0.0,0.9,0.905,5.335670092717667,1.2915169940302822,1000 +Linear,Logistic,0.0,0.95,0.955,6.357843058957276,1.2915169940302822,1000 +Linear,Logistic,1.0,0.9,0.908,5.417512107920403,1.2804808010829234,1000 +Linear,Logistic,1.0,0.95,0.957,6.455363835025866,1.2804808010829234,1000 +Linear,Logistic,2.0,0.9,0.906,5.366403391397197,1.2873523314776736,1000 +Linear,Logistic,2.0,0.95,0.957,6.3944640430685675,1.2873523314776736,1000 diff --git a/results/irm/irm_apo_coverage_apos.csv b/results/irm/irm_apo_coverage_apos.csv index 9450b8bf..9ba88abf 100644 --- a/results/irm/irm_apo_coverage_apos.csv +++ b/results/irm/irm_apo_coverage_apos.csv @@ -1,9 +1,9 @@ Learner g,Learner m,level,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition -LGBM,LGBM,0.9,0.9063333333333333,28.21021843292038,7.1182934844676025,0.925,36.1943517986729,1000 -LGBM,LGBM,0.95,0.9583333333333334,33.61454856442563,7.1182934844676025,0.973,40.868020474452145,1000 -LGBM,Logistic,0.9,0.9153333333333333,6.789316986971467,1.5714825799469632,0.924,8.394070597867524,1000 -LGBM,Logistic,0.95,0.963,8.089970168806186,1.5714825799469632,0.959,9.592284619044477,1000 -Linear,LGBM,0.9,0.9266666666666666,7.903418354805325,1.726298806364679,0.939,9.850445774370156,1000 -Linear,LGBM,0.95,0.967,9.41750382912841,1.726298806364679,0.974,11.234392064837424,1000 -Linear,Logistic,0.9,0.904,5.372532806955153,1.2803595601237736,0.906,5.7964894175017925,1000 -Linear,Logistic,0.95,0.957,6.4017676921854445,1.2803595601237736,0.953,6.8182385960659,1000 +LGBM,LGBM,0.9,0.9063333333333333,28.21021843292038,7.117854990409516,0.926,36.1943517986729,1000 +LGBM,LGBM,0.95,0.958,33.61454856442563,7.117854990409516,0.973,40.868020474452145,1000 +LGBM,Logistic,0.9,0.9126666666666666,6.789316986971467,1.5740024109154114,0.922,8.394070597867524,1000 +LGBM,Logistic,0.95,0.962,8.089970168806186,1.5740024109154114,0.959,9.592284619044477,1000 +Linear,LGBM,0.9,0.927,7.903418354805325,1.730081847384891,0.936,9.850445774370156,1000 +Linear,LGBM,0.95,0.968,9.41750382912841,1.730081847384891,0.974,11.234392064837424,1000 +Linear,Logistic,0.9,0.9053333333333333,5.372532806955153,1.2840405983352503,0.901,5.7964894175017925,1000 +Linear,Logistic,0.95,0.9556666666666667,6.4017676921854445,1.2840405983352503,0.952,6.8182385960659,1000 diff --git a/results/irm/irm_apo_coverage_apos_contrast.csv b/results/irm/irm_apo_coverage_apos_contrast.csv index c663a0b4..0c2215ca 100644 --- a/results/irm/irm_apo_coverage_apos_contrast.csv +++ b/results/irm/irm_apo_coverage_apos_contrast.csv @@ -1,9 +1,9 @@ Learner g,Learner m,level,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition -LGBM,LGBM,0.9,0.8885,37.87536523730769,9.786356896553793,0.898,44.828925781533165,1000 -LGBM,LGBM,0.95,0.9485,45.13128131893863,9.786356896553793,0.965,51.42387403222799,1000 -LGBM,Logistic,0.9,0.9275,5.725128733165455,1.2679354841429349,0.926,6.774304315459575,1000 -LGBM,Logistic,0.95,0.962,6.821911652197591,1.2679354841429349,0.961,7.7654094416927215,1000 -Linear,LGBM,0.9,0.958,7.430953406859927,1.5066894086042675,0.975,8.798628447594016,1000 -Linear,LGBM,0.95,0.9885,8.85452711998085,1.5066894086042675,0.992,10.090116294778499,1000 -Linear,Logistic,0.9,0.873,1.1425883251377777,0.29743753232068365,0.869,1.3505222299078565,1000 -Linear,Logistic,0.95,0.9335,1.3614779635902856,0.29743753232068365,0.916,1.5496710496635275,1000 +LGBM,LGBM,0.9,0.8885,37.87536523730769,9.786311172069826,0.898,44.828925781533165,1000 +LGBM,LGBM,0.95,0.9485,45.13128131893863,9.786311172069826,0.965,51.42387403222799,1000 +LGBM,Logistic,0.9,0.9275,5.725128733165455,1.2679687960356285,0.926,6.774304315459575,1000 +LGBM,Logistic,0.95,0.962,6.821911652197591,1.2679687960356285,0.961,7.7654094416927215,1000 +Linear,LGBM,0.9,0.958,7.430953406859927,1.506712453445455,0.975,8.798628447594016,1000 +Linear,LGBM,0.95,0.9885,8.85452711998085,1.506712453445455,0.992,10.090116294778499,1000 +Linear,Logistic,0.9,0.873,1.1425883251377777,0.2976271580660105,0.871,1.3505222299078565,1000 +Linear,Logistic,0.95,0.933,1.3614779635902856,0.2976271580660105,0.916,1.5496710496635275,1000 diff --git a/results/irm/irm_apo_coverage_metadata.csv b/results/irm/irm_apo_coverage_metadata.csv index 5fac53ab..46811009 100644 --- a/results/irm/irm_apo_coverage_metadata.csv +++ b/results/irm/irm_apo_coverage_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_apo_coverage.py,2025-01-08 15:02:49,10054.695460557938,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_apo_coverage.py,2025-02-17 14:15:48,5805.248526096344,3.12.9,500,1000 diff --git a/results/irm/irm_apo_sensitivity.csv b/results/irm/irm_apo_sensitivity.csv new file mode 100644 index 00000000..ef046c96 --- /dev/null +++ b/results/irm/irm_apo_sensitivity.csv @@ -0,0 +1,9 @@ +Learner g,Learner m,level,Coverage,CI Length,Bias,Coverage (Lower),Coverage (Upper),RV,RVa,Bias (Lower),Bias (Upper),CI Bound Length,repetition +LGBM,LGBM,0.9,0.0,0.1850901948571434,0.1728687505886857,0.94,1.0,0.11948035202564793,0.07154674763559875,0.031810817504062554,0.31670282436556935,0.431995754637111,100 +LGBM,LGBM,0.95,0.0,0.2205485703210263,0.1728687505886857,0.98,1.0,0.11948035202564793,0.05745432815748635,0.031810817504062554,0.31670282436556935,0.4729106257831797,100 +LGBM,Logistic Regr.,0.9,0.0,0.18131193168905674,0.15121724930885022,0.99,1.0,0.10434537281186201,0.05696357062448621,0.021088484865101825,0.29646127759826496,0.43179145449456513,100 +LGBM,Logistic Regr.,0.95,0.03,0.2160464920739249,0.15121724930885022,1.0,1.0,0.10434537281186201,0.04307260925250725,0.021088484865101825,0.29646127759826496,0.47184900436363103,100 +Linear Reg.,LGBM,0.9,0.0,0.18663164735761398,0.17472687461915876,0.92,1.0,0.12056131194284912,0.07240650957100214,0.03257050661484845,0.31875766896448493,0.4335837394031654,100 +Linear Reg.,LGBM,0.95,0.0,0.2223853242639293,0.17472687461915876,0.99,1.0,0.12056131194284912,0.05822924694927436,0.03257050661484845,0.31875766896448493,0.47483724706967173,100 +Linear Reg.,Logistic Regr.,0.9,0.56,0.1828483122747645,0.08933864685804803,1.0,1.0,0.06284434571148402,0.015959822013915787,0.05670070892745642,0.23481634587221883,0.43345409695347575,100 +Linear Reg.,Logistic Regr.,0.95,0.71,0.21787720245762926,0.08933864685804803,1.0,1.0,0.06284434571148402,0.007943564313924893,0.05670070892745642,0.23481634587221883,0.4738504980497505,100 diff --git a/results/irm/irm_apo_sensitivity_metadata.csv b/results/irm/irm_apo_sensitivity_metadata.csv new file mode 100644 index 00000000..6900b314 --- /dev/null +++ b/results/irm/irm_apo_sensitivity_metadata.csv @@ -0,0 +1,2 @@ +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Sensitivity Errors,Number of observations,Number of repetitions +0.10.dev0,irm_apo_sensitivity.py,2025-02-17 15:18:22,9570.017815113068,3.12.9,0,10000,100 diff --git a/results/irm/irm_ate_coverage.csv b/results/irm/irm_ate_coverage.csv index 0de042e0..205cc835 100644 --- a/results/irm/irm_ate_coverage.csv +++ b/results/irm/irm_ate_coverage.csv @@ -1,9 +1,9 @@ Learner g,Learner m,level,Coverage,CI Length,Bias,repetition -Lasso,Logistic Regression,0.9,0.875,0.467771360497192,0.12336884150536219,1000 -Lasso,Logistic Regression,0.95,0.935,0.5573839547492132,0.12336884150536219,1000 -Lasso,Random Forest,0.9,0.906,0.6047145364621751,0.1478860899164366,1000 -Lasso,Random Forest,0.95,0.959,0.7205618135094181,0.1478860899164366,1000 -Random Forest,Logistic Regression,0.9,0.803,0.5176227656141883,0.15019297474976653,1000 -Random Forest,Logistic Regression,0.95,0.883,0.6167855677602845,0.15019297474976653,1000 -Random Forest,Random Forest,0.9,0.899,0.6330777189487165,0.15360554047241584,1000 -Random Forest,Random Forest,0.95,0.959,0.7543586299857805,0.15360554047241584,1000 +Lasso,Logistic Regression,0.9,0.875,0.467771360497192,0.12336884150536215,1000 +Lasso,Logistic Regression,0.95,0.935,0.5573839547492131,0.12336884150536215,1000 +Lasso,Random Forest,0.9,0.918,0.6075675690766795,0.14883418949014868,1000 +Lasso,Random Forest,0.95,0.96,0.7239614115523821,0.14883418949014868,1000 +Random Forest,Logistic Regression,0.9,0.794,0.5179181614571521,0.15026458264746098,1000 +Random Forest,Logistic Regression,0.95,0.886,0.6171375536172053,0.15026458264746098,1000 +Random Forest,Random Forest,0.9,0.887,0.6267427414652039,0.1523909615423744,1000 +Random Forest,Random Forest,0.95,0.951,0.7468100387268922,0.1523909615423744,1000 diff --git a/results/irm/irm_ate_coverage_metadata.csv b/results/irm/irm_ate_coverage_metadata.csv index a6fbf57c..ebc06662 100644 --- a/results/irm/irm_ate_coverage_metadata.csv +++ b/results/irm/irm_ate_coverage_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_ate_coverage.py,2025-01-08 13:14:33,3543.555382013321,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_ate_coverage.py,2025-02-17 13:38:59,3587.198892593384,3.12.9,500,1000 diff --git a/results/irm/irm_ate_sensitivity.csv b/results/irm/irm_ate_sensitivity.csv index 6e107706..b81d88be 100644 --- a/results/irm/irm_ate_sensitivity.csv +++ b/results/irm/irm_ate_sensitivity.csv @@ -1,9 +1,9 @@ -Learner g,Learner m,level,Coverage,CI Length,Bias,Coverage (Lower),Coverage (Upper),RV,RVa,Bias (Lower),Bias (Upper),repetition -LGBM,LGBM,0.9,0.112,0.266748233354866,0.17891290135375168,0.962,1.0,0.12379347892727971,0.05409589192160397,0.04254708028278409,0.32210978560617337,500 -LGBM,LGBM,0.95,0.318,0.31785012462427936,0.17891290135375168,0.998,1.0,0.12379347892727971,0.03441021667548556,0.04254708028278409,0.32210978560617337,500 -LGBM,Logistic Regr.,0.9,0.292,0.2577778025822409,0.14922926552528684,1.0,1.0,0.10066571951295798,0.03493291437943745,0.029012990398602386,0.2979424530565633,500 -LGBM,Logistic Regr.,0.95,0.548,0.30716119707955875,0.14922926552528684,1.0,1.0,0.10066571951295798,0.01869752301454861,0.029012990398602386,0.2979424530565633,500 -Linear Reg.,LGBM,0.9,0.122,0.2675665174758639,0.17873104426193565,0.964,1.0,0.12647219547900976,0.05512739569620471,0.04513946154555041,0.31857328180879246,500 -Linear Reg.,LGBM,0.95,0.314,0.31882517029399604,0.17873104426193565,0.998,1.0,0.12647219547900976,0.035017588858111126,0.04513946154555041,0.31857328180879246,500 -Linear Reg.,Logistic Regr.,0.9,0.86,0.2592281409673473,0.08970251629543106,1.0,1.0,0.06300567732617765,0.006719868195974334,0.05720312141493262,0.23496869651774063,500 -Linear Reg.,Logistic Regr.,0.95,0.974,0.30888938185760084,0.08970251629543106,1.0,1.0,0.06300567732617765,0.0014945204694376396,0.05720312141493262,0.23496869651774063,500 +Learner g,Learner m,level,Coverage,CI Length,Bias,Coverage (Lower),Coverage (Upper),RV,RVa,Bias (Lower),Bias (Upper),CI Bound Length,repetition +LGBM,LGBM,0.9,0.0,0.18456010957050942,0.17337243340910458,0.95,1.0,0.11953784478980024,0.07183819359114625,0.03136325666565055,0.31756830147937315,0.43226987036337094,100 +LGBM,LGBM,0.95,0.0,0.21991693474354154,0.17337243340910458,0.99,1.0,0.11953784478980024,0.057824558257341376,0.03136325666565055,0.31756830147937315,0.47305732219900365,100 +LGBM,Logistic Regr.,0.9,0.0,0.1812877961255176,0.15119732710067918,0.98,1.0,0.10434720873886104,0.05696666617037721,0.021677485760078018,0.29642303876248005,0.4317358050796223,100 +LGBM,Logistic Regr.,0.95,0.04,0.2160177327761319,0.15119732710067918,1.0,1.0,0.10434720873886104,0.04307608641610582,0.021677485760078018,0.29642303876248005,0.47178796413164226,100 +Linear Reg.,LGBM,0.9,0.0,0.18605425587835156,0.17399977827987403,0.94,1.0,0.11990868556653686,0.07190610415701137,0.03212149599746574,0.3182707711756607,0.4335852019350203,100 +Linear Reg.,LGBM,0.95,0.0,0.22169731988117328,0.17399977827987403,1.0,1.0,0.11990868556653686,0.057787828308944675,0.03212149599746574,0.3182707711756607,0.4747029383060918,100 +Linear Reg.,Logistic Regr.,0.9,0.57,0.1828627595282207,0.08930473133509086,1.0,1.0,0.06281551261599926,0.015943919677551065,0.05677070380042624,0.234795381047018,0.4334912988005637,100 +Linear Reg.,Logistic Regr.,0.95,0.72,0.21789441742191876,0.08930473133509086,1.0,1.0,0.06281551261599926,0.007945274454461557,0.05677070380042624,0.234795381047018,0.4738909034178378,100 diff --git a/results/irm/irm_ate_sensitivity_metadata.csv b/results/irm/irm_ate_sensitivity_metadata.csv index c038f240..4860b2a4 100644 --- a/results/irm/irm_ate_sensitivity_metadata.csv +++ b/results/irm/irm_ate_sensitivity_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_ate_sensitivity.py,2025-01-08 14:51:20,9351.407655954361,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_ate_sensitivity.py,2025-02-17 13:05:22,1572.6342079639435,3.12.9,10000,100 diff --git a/results/irm/irm_atte_coverage.csv b/results/irm/irm_atte_coverage.csv index 3f486c5f..c319959f 100644 --- a/results/irm/irm_atte_coverage.csv +++ b/results/irm/irm_atte_coverage.csv @@ -1,9 +1,9 @@ Learner g,Learner m,level,Coverage,CI Length,Bias,repetition -Lasso,Logistic Regression,0.9,0.887,0.5331759172799809,0.13448723008182129,1000 -Lasso,Logistic Regression,0.95,0.942,0.6353182910443254,0.13448723008182129,1000 -Lasso,Random Forest,0.9,0.895,0.7320685495582887,0.18096827990713532,1000 -Lasso,Random Forest,0.95,0.948,0.8723134799586961,0.18096827990713532,1000 -Random Forest,Logistic Regression,0.9,0.871,0.5508997607447915,0.1495507044107987,1000 -Random Forest,Logistic Regression,0.95,0.928,0.6564375531412434,0.1495507044107987,1000 -Random Forest,Random Forest,0.9,0.901,0.7500491227406757,0.18282767194851973,1000 -Random Forest,Random Forest,0.95,0.948,0.8937386543823805,0.18282767194851973,1000 +Lasso,Logistic Regression,0.9,0.889,0.5331759172799809,0.1351855584332284,1000 +Lasso,Logistic Regression,0.95,0.937,0.6353182910443254,0.1351855584332284,1000 +Lasso,Random Forest,0.9,0.9,0.7254806184755598,0.18158167210289639,1000 +Lasso,Random Forest,0.95,0.955,0.864463475895592,0.18158167210289639,1000 +Random Forest,Logistic Regression,0.9,0.869,0.5498973024315036,0.14982335722803167,1000 +Random Forest,Logistic Regression,0.95,0.919,0.6552430503855857,0.14982335722803167,1000 +Random Forest,Random Forest,0.9,0.9,0.754141932378241,0.18774885787201864,1000 +Random Forest,Random Forest,0.95,0.944,0.8986155378653647,0.18774885787201864,1000 diff --git a/results/irm/irm_atte_coverage_metadata.csv b/results/irm/irm_atte_coverage_metadata.csv index 18af13a4..d250b9b4 100644 --- a/results/irm/irm_atte_coverage_metadata.csv +++ b/results/irm/irm_atte_coverage_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_atte_coverage.py,2025-01-08 13:14:10,3538.650551557541,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_atte_coverage.py,2025-02-17 13:39:11,3599.8705773353577,3.12.9,500,1000 diff --git a/results/irm/irm_atte_sensitivity.csv b/results/irm/irm_atte_sensitivity.csv index 7483e17e..38c5529d 100644 --- a/results/irm/irm_atte_sensitivity.csv +++ b/results/irm/irm_atte_sensitivity.csv @@ -1,9 +1,9 @@ Learner g,Learner m,level,Coverage,CI Length,Bias,Coverage (Lower),Coverage (Upper),RV,RVa,Bias (Lower),Bias (Upper),repetition -LGBM,LGBM,0.9,0.702,0.348892741328716,0.13535473129404849,0.95,1.0,0.10509293446782587,0.024288450266824614,0.0649958270073569,0.25876253126104387,500 -LGBM,LGBM,0.95,0.826,0.41573134306126747,0.13535473129404849,0.982,1.0,0.10509293446782587,0.012452066983782133,0.0649958270073569,0.25876253126104387,500 -LGBM,Logistic Regr.,0.9,0.714,0.34666910502599596,0.13078975736827733,0.964,1.0,0.09817018528214369,0.022246998237972562,0.06545442330246619,0.2589976808508975,500 -LGBM,Logistic Regr.,0.95,0.834,0.41308171698108886,0.13078975736827733,0.984,1.0,0.09817018528214369,0.010949342084431752,0.06545442330246619,0.2589976808508975,500 -Linear Reg.,LGBM,0.9,0.754,0.3496967006881292,0.12455057551341779,0.962,1.0,0.09867724125956986,0.02021759355041513,0.06504946816195573,0.24393419014571052,500 -Linear Reg.,LGBM,0.95,0.858,0.416689319724762,0.12455057551341779,0.986,1.0,0.09867724125956986,0.009856683129418061,0.06504946816195573,0.24393419014571052,500 -Linear Reg.,Logistic Regr.,0.9,0.948,0.3502540540945954,0.07444772768321124,0.996,1.0,0.05840145836627319,0.0041811437412796835,0.09544484272838329,0.17545346180009289,500 -Linear Reg.,Logistic Regr.,0.95,0.976,0.41735344727108903,0.07444772768321124,0.998,1.0,0.05840145836627319,0.001573924919578166,0.09544484272838329,0.17545346180009289,500 +LGBM,LGBM,0.9,0.49,0.24396025943517974,0.12078793981792306,0.99,1.0,0.09312792462754033,0.028226037071058148,0.04146009435287488,0.25069761854571676,100 +LGBM,LGBM,0.95,0.67,0.2906965789036176,0.12078793981792306,1.0,1.0,0.09312792462754033,0.01585795047607936,0.04146009435287488,0.25069761854571676,100 +LGBM,Logistic Regr.,0.9,0.53,0.24426158440218287,0.11892095837667906,0.97,1.0,0.0924028336595813,0.027272723350173122,0.042150350065681345,0.2477418180760181,100 +LGBM,Logistic Regr.,0.95,0.66,0.2910556297475905,0.11892095837667906,1.0,1.0,0.0924028336595813,0.015364138052493198,0.042150350065681345,0.2477418180760181,100 +Linear Reg.,LGBM,0.9,0.51,0.24463965072027827,0.11756694550094914,0.98,1.0,0.0905239932672192,0.02606692134127496,0.041919687829977396,0.24763902427345857,100 +Linear Reg.,LGBM,0.95,0.71,0.2915061235514726,0.11756694550094914,1.0,1.0,0.0905239932672192,0.014359552326725858,0.041919687829977396,0.24763902427345857,100 +Linear Reg.,Logistic Regr.,0.9,0.95,0.24728188741382112,0.05084320840039999,1.0,1.0,0.04043535771019453,0.0022454610340575335,0.09338366625402501,0.1656214414367636,100 +Linear Reg.,Logistic Regr.,0.95,0.99,0.29465454276222763,0.05084320840039999,1.0,1.0,0.04043535771019453,0.00047406259110364837,0.09338366625402501,0.1656214414367636,100 diff --git a/results/irm/irm_atte_sensitivity_metadata.csv b/results/irm/irm_atte_sensitivity_metadata.csv index 43535236..2d2a013a 100644 --- a/results/irm/irm_atte_sensitivity_metadata.csv +++ b/results/irm/irm_atte_sensitivity_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_atte_sensitivity.py,2025-01-08 15:14:53,10782.067339658737,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_atte_sensitivity.py,2025-02-17 13:05:01,1552.7866415977478,3.12.9,10000,100 diff --git a/results/irm/irm_cate_coverage.csv b/results/irm/irm_cate_coverage.csv index 42b5454d..788025e4 100644 --- a/results/irm/irm_cate_coverage.csv +++ b/results/irm/irm_cate_coverage.csv @@ -5,5 +5,5 @@ LGBM,Logistic Regression,0.9,0.8889914999999999,0.2354452920928757,0.05850699314 LGBM,Logistic Regression,0.95,0.942037,0.28055036951027373,0.05850699314525499,0.996,0.5990789624833718,1000 Lasso,LGBM,0.9,0.896289,0.6428742805642967,0.15829333167540963,1.0,1.6320898187101593,1000 Lasso,LGBM,0.95,0.9491539999999999,0.7660319531461224,0.15829333167540963,1.0,1.6337799284392311,1000 -Lasso,Logistic Regression,0.9,0.8892920000000001,0.2473839966772624,0.061668240816652016,0.998,0.629839636678611,1000 -Lasso,Logistic Regression,0.95,0.9413365,0.29477621345410787,0.061668240816652016,0.997,0.6288417962868702,1000 +Lasso,Logistic Regression,0.9,0.8892920000000001,0.24738399667726244,0.061668240816652016,0.998,0.629839636678611,1000 +Lasso,Logistic Regression,0.95,0.9413365,0.29477621345410787,0.061668240816652016,0.997,0.6288417962868703,1000 diff --git a/results/irm/irm_cate_coverage_metadata.csv b/results/irm/irm_cate_coverage_metadata.csv index 24d54ab6..ad0ca97f 100644 --- a/results/irm/irm_cate_coverage_metadata.csv +++ b/results/irm/irm_cate_coverage_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_cate_coverage.py,2025-01-08 14:19:05,7429.032528162003,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_cate_coverage.py,2025-02-17 14:09:01,5386.899612903595,3.12.9,2000,1000 diff --git a/results/irm/irm_gate_coverage.csv b/results/irm/irm_gate_coverage.csv index f5dde966..4208c4b4 100644 --- a/results/irm/irm_gate_coverage.csv +++ b/results/irm/irm_gate_coverage.csv @@ -1,7 +1,7 @@ Learner g,Learner m,level,Coverage,CI Length,Bias,Uniform Coverage,Uniform CI Length,repetition LGBM,LGBM,0.9,0.941,2.1611526373873855,0.48511146422993706,1.0,5.063646757352436,1000 LGBM,LGBM,0.95,0.9766666666666667,2.5751722007164304,0.48511146422993706,1.0,5.085406768554818,1000 -LGBM,Logistic Regression,0.9,0.916,0.3904280687581478,0.08854865701618204,0.997,0.9186129412168421,1000 +LGBM,Logistic Regression,0.9,0.916,0.3904280687581478,0.08854865701618204,0.997,0.9186129412168422,1000 LGBM,Logistic Regression,0.95,0.9606666666666667,0.4652237383199531,0.08854865701618204,0.998,0.9197763724365766,1000 Lasso,LGBM,0.9,0.9043333333333333,2.047596498760145,0.49693265710872336,1.0,4.807194790243196,1000 Lasso,LGBM,0.95,0.959,2.43986171576749,0.49693265710872336,1.0,4.819825928994352,1000 diff --git a/results/irm/irm_gate_coverage_metadata.csv b/results/irm/irm_gate_coverage_metadata.csv index b7034be4..5273a309 100644 --- a/results/irm/irm_gate_coverage_metadata.csv +++ b/results/irm/irm_gate_coverage_metadata.csv @@ -1,2 +1,2 @@ -DoubleML Version,Script,Date,Total Runtime (seconds),Python Version -0.10.dev0,irm_gate_coverage.py,2025-01-08 13:08:04,3153.3980271816254,3.12.8 +DoubleML Version,Script,Date,Total Runtime (seconds),Python Version,Number of observations,Number of repetitions +0.10.dev0,irm_gate_coverage.py,2025-02-17 13:12:35,2007.2094383239746,3.12.9,500,1000 diff --git a/scripts/did/did_cs_atte_coverage.py b/scripts/did/did_cs_atte_coverage.py index 3886ea8e..e2d27963 100644 --- a/scripts/did/did_cs_atte_coverage.py +++ b/scripts/did/did_cs_atte_coverage.py @@ -125,6 +125,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/cvar_coverage.py b/scripts/irm/cvar_coverage.py index 4a166d94..4ab6e578 100644 --- a/scripts/irm/cvar_coverage.py +++ b/scripts/irm/cvar_coverage.py @@ -233,6 +233,8 @@ def dgp(n=200, p=5): 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_apo_ate_sensitivity.py b/scripts/irm/irm_apo_ate_sensitivity.py new file mode 100644 index 00000000..4abc7d35 --- /dev/null +++ b/scripts/irm/irm_apo_ate_sensitivity.py @@ -0,0 +1,186 @@ +import numpy as np +import pandas as pd +from datetime import datetime +import time +import sys + +from sklearn.linear_model import LinearRegression, LogisticRegression +from lightgbm import LGBMRegressor, LGBMClassifier + +import doubleml as dml +from doubleml.datasets import make_confounded_irm_data + +# Number of repetitions +n_rep = 100 +max_runtime = 5.5 * 3600 # 5.5 hours in seconds + +# DGP pars +n_obs = 10000 +theta = 5.0 +trimming_threshold = 0.05 + +dgp_pars = { + "gamma_a": 0.198, + "beta_a": 0.582, + "theta": theta, + "var_epsilon_y": 1.0, + "trimming_threshold": trimming_threshold, + "linear": False, +} + +# test inputs +np.random.seed(42) +dgp_dict = make_confounded_irm_data(n_obs=int(1e+6), **dgp_pars) + +oracle_dict = dgp_dict['oracle_values'] +rho = oracle_dict['rho_ate'] +cf_y = oracle_dict['cf_y'] +cf_d = oracle_dict['cf_d_ate'] + +print(f"Confounding factor for Y: {cf_y}") +print(f"Confounding factor for D: {cf_d}") +print(f"Rho: {rho}") + +# to get the best possible comparison between different learners (and settings) we first simulate all datasets +np.random.seed(42) +datasets = [] +for i in range(n_rep): + dgp_dict = make_confounded_irm_data(n_obs=n_obs, **dgp_pars) + datasets.append(dgp_dict) + +# set up hyperparameters +hyperparam_dict = { + "learner_g": [("Linear Reg.", LinearRegression()), + ("LGBM", LGBMRegressor(n_estimators=500, learning_rate=0.01, min_child_samples=10, verbose=-1))], + "learner_m": [("Logistic Regr.", LogisticRegression()), + ("LGBM", LGBMClassifier(n_estimators=500, learning_rate=0.01, min_child_samples=10, verbose=-1)),], + "level": [0.95, 0.90], + "treatment_levels": [0.0, 1.0], +} + +# set up the results dataframe +df_results_detailed = pd.DataFrame() +sensitivity_err_count = 0 + +# start simulation +np.random.seed(42) +start_time = time.time() + +for i_rep in range(n_rep): + print(f"Repetition: {i_rep}/{n_rep}", end="\r") + + # Check the elapsed time + elapsed_time = time.time() - start_time + if elapsed_time > max_runtime: + print("Maximum runtime exceeded. Stopping the simulation.") + break + + # define the DoubleML data object + dgp_dict = datasets[i_rep] + + x_cols = [f'X{i + 1}' for i in np.arange(dgp_dict['x'].shape[1])] + df = pd.DataFrame(np.column_stack((dgp_dict['x'], dgp_dict['y'], dgp_dict['d'])), columns=x_cols + ['y', 'd']) + obj_dml_data = dml.DoubleMLData(df, 'y', 'd') + + for learner_g_idx, (learner_g_name, ml_g) in enumerate(hyperparam_dict["learner_g"]): + for learner_m_idx, (learner_m_name, ml_m) in enumerate(hyperparam_dict["learner_m"]): + # Set machine learning methods for g & m + # calculate the APOs + dml_apos = dml.DoubleMLAPOS( + obj_dml_data=obj_dml_data, + ml_g=ml_g, + ml_m=ml_m, + treatment_levels=hyperparam_dict["treatment_levels"], + ) + + for level_idx, level in enumerate(hyperparam_dict["level"]): + dml_apos.fit(n_jobs_cv=5) + effects = dml_apos.coef + dml_apos.fit(n_jobs_cv=5) + effects = dml_apos.coef + + causal_contrast_model = dml_apos.causal_contrast(reference_levels=0) + estimate = causal_contrast_model.thetas + + for level_idx, level in enumerate(hyperparam_dict["level"]): + # estimate = causal_contrast_model.coef[0] + confint = causal_contrast_model.confint(level=level) + coverage = (confint.iloc[0, 0] < theta) & (theta < confint.iloc[0, 1]) + ci_length = confint.iloc[0, 1] - confint.iloc[0, 0] + + # test sensitivity parameters + # try to run sensitivity analysis + try: + causal_contrast_model.sensitivity_analysis(cf_y=cf_y, cf_d=cf_d, rho=rho, level=level ,null_hypothesis=theta) + cover_lower = theta >= causal_contrast_model.sensitivity_params['ci']['lower'] + cover_upper = theta <= causal_contrast_model.sensitivity_params['ci']['upper'] + ci_bound_width = causal_contrast_model.sensitivity_params['ci']['upper'][0]- causal_contrast_model.sensitivity_params['ci']['lower'][0] + rv = causal_contrast_model.sensitivity_params['rv'] + rva = causal_contrast_model.sensitivity_params['rva'] + bias_lower = abs(theta - causal_contrast_model.sensitivity_params['theta']['lower']) + bias_upper = abs(theta - causal_contrast_model.sensitivity_params['theta']['upper']) + success_eval = 1 + except Exception as e: + sensitivity_err_count += 1 + continue + + df_results_detailed = pd.concat( + (df_results_detailed, + pd.DataFrame({ + "Coverage": coverage.astype(int), + "CI Length": confint.iloc[0, 1] - confint.iloc[0, 0], + "Bias": abs(estimate - theta), + "Coverage (Lower)": cover_lower.astype(int), + "Coverage (Upper)": cover_upper.astype(int), + "RV": rv, + "RVa": rva, + "Bias (Lower)": bias_lower, + "Bias (Upper)": bias_upper, + "CI Bound Length": ci_bound_width, + "Learner g": learner_g_name, + "Learner m": learner_m_name, + "level": level, + "repetition": i_rep, + "success_eval": success_eval}, index=[0])), + ignore_index=True) + +# aggregate results only if success_eval == 1 +df_results_detailed = df_results_detailed[df_results_detailed["success_eval"] == 1] + +df_results = df_results_detailed.groupby( + ["Learner g", "Learner m", "level"]).agg( + {"Coverage": "mean", + "CI Length": "mean", + "Bias": "mean", + "Coverage (Lower)": "mean", + "Coverage (Upper)": "mean", + "RV": "mean", + "RVa": "mean", + "Bias (Lower)": "mean", + "Bias (Upper)": "mean", + "CI Bound Length": "mean", + "repetition": "count"} + ).reset_index() +print(df_results) + +end_time = time.time() +total_runtime = end_time - start_time + +# save results +script_name = "irm_apo_sensitivity.py" +path = "results/irm/irm_apo_sensitivity" + +metadata = pd.DataFrame({ + 'DoubleML Version': [dml.__version__], + 'Script': [script_name], + 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], + 'Total Runtime (seconds)': [total_runtime], + 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Sensitivity Errors': [sensitivity_err_count], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], +}) +print(metadata) + +df_results.to_csv(f"{path}.csv", index=False) +metadata.to_csv(f"{path}_metadata.csv", index=False) diff --git a/scripts/irm/irm_apo_coverage.py b/scripts/irm/irm_apo_coverage.py index 3f5f7aec..96f1ceb7 100644 --- a/scripts/irm/irm_apo_coverage.py +++ b/scripts/irm/irm_apo_coverage.py @@ -218,6 +218,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_ate_coverage.py b/scripts/irm/irm_ate_coverage.py index 82e7fa7e..36a656d7 100644 --- a/scripts/irm/irm_ate_coverage.py +++ b/scripts/irm/irm_ate_coverage.py @@ -105,6 +105,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_ate_sensitivity.py b/scripts/irm/irm_ate_sensitivity.py index d29fa778..31ddfe89 100644 --- a/scripts/irm/irm_ate_sensitivity.py +++ b/scripts/irm/irm_ate_sensitivity.py @@ -11,11 +11,11 @@ from doubleml.datasets import make_confounded_irm_data # Number of repetitions -n_rep = 500 +n_rep = 100 max_runtime = 5.5 * 3600 # 5.5 hours in seconds # DGP pars -n_obs = 5000 +n_obs = 10000 theta = 5.0 trimming_threshold = 0.05 @@ -101,6 +101,7 @@ dml_irm.sensitivity_analysis(cf_y=cf_y, cf_d=cf_d, rho=rho, level=level, null_hypothesis=theta) cover_lower = theta >= dml_irm.sensitivity_params['ci']['lower'] cover_upper = theta <= dml_irm.sensitivity_params['ci']['upper'] + ci_bound_width = dml_irm.sensitivity_params['ci']['upper'][0]- dml_irm.sensitivity_params['ci']['lower'][0] rv = dml_irm.sensitivity_params['rv'] rva = dml_irm.sensitivity_params['rva'] bias_lower = abs(theta - dml_irm.sensitivity_params['theta']['lower']) @@ -118,6 +119,7 @@ "RVa": rva, "Bias (Lower)": bias_lower, "Bias (Upper)": bias_upper, + "CI Bound Length": ci_bound_width, "Learner g": learner_g_name, "Learner m": learner_m_name, "level": level, @@ -135,6 +137,7 @@ "RVa": "mean", "Bias (Lower)": "mean", "Bias (Upper)": "mean", + "CI Bound Length": "mean", "repetition": "count"} ).reset_index() print(df_results) @@ -152,6 +155,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_atte_coverage.py b/scripts/irm/irm_atte_coverage.py index 8b847436..bc156c2f 100644 --- a/scripts/irm/irm_atte_coverage.py +++ b/scripts/irm/irm_atte_coverage.py @@ -134,6 +134,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_atte_sensitivity.py b/scripts/irm/irm_atte_sensitivity.py index aeda6b1a..1412985b 100644 --- a/scripts/irm/irm_atte_sensitivity.py +++ b/scripts/irm/irm_atte_sensitivity.py @@ -11,11 +11,11 @@ from doubleml.datasets import make_confounded_irm_data # Number of repetitions -n_rep = 500 +n_rep = 100 max_runtime = 5.5 * 3600 # 5.5 hours in seconds # DGP pars -n_obs = 5000 +n_obs = 10000 theta = 5.0 trimming_threshold = 0.05 @@ -152,6 +152,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_cate_coverage.py b/scripts/irm/irm_cate_coverage.py index 9843de13..67d658e4 100644 --- a/scripts/irm/irm_cate_coverage.py +++ b/scripts/irm/irm_cate_coverage.py @@ -120,6 +120,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/irm_gate_coverage.py b/scripts/irm/irm_gate_coverage.py index aafd9151..3c097483 100644 --- a/scripts/irm/irm_gate_coverage.py +++ b/scripts/irm/irm_gate_coverage.py @@ -123,6 +123,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/lpq_coverage.py b/scripts/irm/lpq_coverage.py index ab1152ae..a9bfb5a0 100644 --- a/scripts/irm/lpq_coverage.py +++ b/scripts/irm/lpq_coverage.py @@ -250,6 +250,8 @@ def dgp(n=200, p=5): 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/pq_coverage.py b/scripts/irm/pq_coverage.py index 2a6a7270..bcf9a797 100644 --- a/scripts/irm/pq_coverage.py +++ b/scripts/irm/pq_coverage.py @@ -229,6 +229,8 @@ def dgp(n=200, p=5): 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/ssm_mar_ate_coverage.py b/scripts/irm/ssm_mar_ate_coverage.py index 8aa0e561..a6247320 100644 --- a/scripts/irm/ssm_mar_ate_coverage.py +++ b/scripts/irm/ssm_mar_ate_coverage.py @@ -117,6 +117,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/irm/ssm_nonignorable_ate_coverage.py b/scripts/irm/ssm_nonignorable_ate_coverage.py index ebbacfe3..8435aa0a 100644 --- a/scripts/irm/ssm_nonignorable_ate_coverage.py +++ b/scripts/irm/ssm_nonignorable_ate_coverage.py @@ -117,6 +117,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/plm/pliv_late_coverage.py b/scripts/plm/pliv_late_coverage.py index 6cc43290..9806d0ea 100644 --- a/scripts/plm/pliv_late_coverage.py +++ b/scripts/plm/pliv_late_coverage.py @@ -129,6 +129,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/plm/plr_ate_coverage.py b/scripts/plm/plr_ate_coverage.py index ab715d77..2af49b55 100644 --- a/scripts/plm/plr_ate_coverage.py +++ b/scripts/plm/plr_ate_coverage.py @@ -121,6 +121,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/plm/plr_ate_sensitivity.py b/scripts/plm/plr_ate_sensitivity.py index 34e73dc6..bc60bae6 100644 --- a/scripts/plm/plr_ate_sensitivity.py +++ b/scripts/plm/plr_ate_sensitivity.py @@ -166,6 +166,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/plm/plr_cate_coverage.py b/scripts/plm/plr_cate_coverage.py index 5fee83d4..d07aa53e 100644 --- a/scripts/plm/plr_cate_coverage.py +++ b/scripts/plm/plr_cate_coverage.py @@ -120,6 +120,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/plm/plr_gate_coverage.py b/scripts/plm/plr_gate_coverage.py index 7c996696..ea41e84c 100644 --- a/scripts/plm/plr_gate_coverage.py +++ b/scripts/plm/plr_gate_coverage.py @@ -123,6 +123,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata) diff --git a/scripts/rdd/rdd_sharp_coverage.py b/scripts/rdd/rdd_sharp_coverage.py index bff96f94..f3cd6e46 100644 --- a/scripts/rdd/rdd_sharp_coverage.py +++ b/scripts/rdd/rdd_sharp_coverage.py @@ -155,6 +155,8 @@ 'Date': [datetime.now().strftime("%Y-%m-%d %H:%M:%S")], 'Total Runtime (seconds)': [total_runtime], 'Python Version': [f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"], + 'Number of observations': [n_obs], + 'Number of repetitions': [n_rep], }) print(metadata)