Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doubleml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .data import DoubleMLClusterData, DoubleMLData, DoubleMLDIDData, DoubleMLPanelData, DoubleMLRDDData, DoubleMLSSMData
from .did.did import DoubleMLDID
from .did.did_cs import DoubleMLDIDCS
from .double_ml_framework import DoubleMLFramework, concat
from .double_ml_framework import DoubleMLCore, DoubleMLFramework, concat
from .irm.apo import DoubleMLAPO
from .irm.apos import DoubleMLAPOS
from .irm.cvar import DoubleMLCVAR
Expand All @@ -21,6 +21,7 @@

__all__ = [
"concat",
"DoubleMLCore",
"DoubleMLFramework",
"DoubleMLPLR",
"DoubleMLPLIV",
Expand Down
5 changes: 3 additions & 2 deletions doubleml/did/tests/test_did_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from doubleml.did.did_aggregation import DoubleMLDIDAggregation
from doubleml.double_ml_framework import DoubleMLFramework
from doubleml.double_ml_framework import DoubleMLCore, DoubleMLFramework
from doubleml.tests._utils import generate_dml_dict


Expand All @@ -28,7 +28,8 @@ def base_framework(n_rep):
psi_b = np.random.normal(size=(n_obs, n_thetas, n_rep))

doubleml_dict = generate_dml_dict(psi_a, psi_b)
return DoubleMLFramework(doubleml_dict)
dml_core = DoubleMLCore(**doubleml_dict)
return DoubleMLFramework(dml_core=dml_core)


@pytest.fixture(scope="module", params=["ones", "random", "zeros", "mixed"])
Expand Down
8 changes: 5 additions & 3 deletions doubleml/did/tests/test_did_aggregation_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from doubleml.did.did_aggregation import DoubleMLDIDAggregation
from doubleml.double_ml_framework import DoubleMLFramework
from doubleml.double_ml_framework import DoubleMLCore, DoubleMLFramework
from doubleml.tests._utils import generate_dml_dict


Expand All @@ -24,7 +24,8 @@ def mock_framework(n_rep, n_thetas):
psi_a = np.ones(shape=(n_obs, n_thetas, n_rep))
psi_b = np.random.normal(size=(n_obs, n_thetas, n_rep))
doubleml_dict = generate_dml_dict(psi_a, psi_b)
return DoubleMLFramework(doubleml_dict)
dml_core = DoubleMLCore(**doubleml_dict)
return DoubleMLFramework(dml_core)


@pytest.fixture
Expand Down Expand Up @@ -67,7 +68,8 @@ def test_invalid_framework_dim():
psi_a = np.ones(shape=(10, 2, 1))
psi_b = np.random.normal(size=(10, 2, 1))
doubleml_dict = generate_dml_dict(psi_a, psi_b)
framework = DoubleMLFramework(doubleml_dict)
dml_core = DoubleMLCore(**doubleml_dict)
framework = DoubleMLFramework(dml_core=dml_core)

# Test with invalid framework dimension
with pytest.raises(ValueError, match="All frameworks must be one-dimensional"):
Expand Down
5 changes: 3 additions & 2 deletions doubleml/did/tests/test_did_aggregation_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from matplotlib.figure import Figure

from doubleml.did.did_aggregation import DoubleMLDIDAggregation
from doubleml.double_ml_framework import DoubleMLFramework
from doubleml.double_ml_framework import DoubleMLCore, DoubleMLFramework
from doubleml.tests._utils import generate_dml_dict


Expand All @@ -23,7 +23,8 @@ def mock_framework(n_rep):
psi_a = np.ones(shape=(n_obs, n_thetas, n_rep))
psi_b = np.random.normal(size=(n_obs, n_thetas, n_rep))
doubleml_dict = generate_dml_dict(psi_a, psi_b)
return DoubleMLFramework(doubleml_dict)
dml_core = DoubleMLCore(**doubleml_dict)
return DoubleMLFramework(dml_core=dml_core)


@pytest.fixture
Expand Down
5 changes: 3 additions & 2 deletions doubleml/did/tests/test_did_aggregation_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from matplotlib.figure import Figure

from doubleml.did.did_aggregation import DoubleMLDIDAggregation
from doubleml.double_ml_framework import DoubleMLFramework
from doubleml.double_ml_framework import DoubleMLCore, DoubleMLFramework
from doubleml.tests._utils import generate_dml_dict


Expand All @@ -24,7 +24,8 @@ def mock_framework(n_rep):
psi_a = np.ones(shape=(n_obs, n_thetas, n_rep))
psi_b = np.random.normal(size=(n_obs, n_thetas, n_rep))
doubleml_dict = generate_dml_dict(psi_a, psi_b)
return DoubleMLFramework(doubleml_dict)
dml_core = DoubleMLCore(**doubleml_dict)
return DoubleMLFramework(dml_core=dml_core)


@pytest.fixture
Expand Down
9 changes: 3 additions & 6 deletions doubleml/double_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from doubleml.data import DoubleMLDIDData, DoubleMLPanelData, DoubleMLRDDData, DoubleMLSSMData
from doubleml.data.base_data import DoubleMLBaseData
from doubleml.double_ml_framework import DoubleMLFramework
from doubleml.double_ml_framework import DoubleMLCore, DoubleMLFramework
from doubleml.double_ml_sampling_mixins import SampleSplittingMixin
from doubleml.utils._checks import _check_external_predictions
from doubleml.utils._estimation import _aggregate_coefs_and_ses, _rmse, _set_external_predictions, _var_est
Expand Down Expand Up @@ -625,14 +625,11 @@ def construct_framework(self):
scaled_psi_reshape = np.transpose(scaled_psi, (0, 2, 1))

doubleml_dict = {
"thetas": self.coef,
"all_thetas": self.all_coef,
"ses": self.se,
"all_ses": self.all_se,
"var_scaling_factors": self._var_scaling_factors,
"scaled_psi": scaled_psi_reshape,
"is_cluster_data": self._is_cluster_data,
"treatment_names": self._dml_data.d_cols,
}

if self._sensitivity_implemented:
Expand Down Expand Up @@ -669,8 +666,8 @@ def construct_framework(self):
},
}
)

doubleml_framework = DoubleMLFramework(doubleml_dict)
dml_core = DoubleMLCore(**doubleml_dict)
doubleml_framework = DoubleMLFramework(dml_core=dml_core, treatment_names=self._dml_data.d_cols)
return doubleml_framework

def bootstrap(self, method="normal", n_rep_boot=500):
Expand Down
Loading
Loading