Skip to content

Commit 92081ad

Browse files
authored
Merge pull request #338 from DoubleML/305-feature-request-integrate-clusters-into-the-doublemldata-class
Integrate clusters into the `DoubleMLData` class, Refactor data generators, refactor sampling using Mixin Class
2 parents d61c040 + d80459b commit 92081ad

File tree

111 files changed

+3453
-2624
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+3453
-2624
lines changed

.github/ISSUE_TEMPLATE/bug_report.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@ body:
2424
label: Minimum reproducible code snippet
2525
description: |
2626
Please provide a short reproducible code snippet. Example:
27-
2827
```python
2928
import numpy as np
3029
import doubleml as dml
31-
from doubleml.datasets import make_plr_CCDDHNR2018
30+
from doubleml.plm.datasets import make_plr_CCDDHNR2018
3231
from sklearn.ensemble import RandomForestRegressor
3332
from sklearn.base import clone
3433
np.random.seed(3141)

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ To submit a **bug report**, you can use our
1515
```python
1616
import numpy as np
1717
import doubleml as dml
18-
from doubleml.datasets import make_plr_CCDDHNR2018
18+
from doubleml.plm.datasets import make_plr_CCDDHNR2018
1919
from sklearn.ensemble import RandomForestRegressor
2020
from sklearn.base import clone
2121
np.random.seed(3141)

doubleml/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import importlib.metadata
22

3-
from .data import DoubleMLClusterData, DoubleMLData
3+
from .data import DoubleMLClusterData, DoubleMLData, DoubleMLDIDData, DoubleMLPanelData, DoubleMLRDDData, DoubleMLSSMData
44
from .did.did import DoubleMLDID
55
from .did.did_cs import DoubleMLDIDCS
66
from .double_ml_framework import DoubleMLFramework, concat
@@ -29,6 +29,10 @@
2929
"DoubleMLIIVM",
3030
"DoubleMLData",
3131
"DoubleMLClusterData",
32+
"DoubleMLDIDData",
33+
"DoubleMLPanelData",
34+
"DoubleMLRDDData",
35+
"DoubleMLSSMData",
3236
"DoubleMLDID",
3337
"DoubleMLDIDCS",
3438
"DoubleMLPQ",

doubleml/data/__init__.py

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,78 @@
22
The :mod:`doubleml.data` module implements data classes for double machine learning.
33
"""
44

5+
import warnings
6+
57
from .base_data import DoubleMLData
6-
from .cluster_data import DoubleMLClusterData
8+
from .did_data import DoubleMLDIDData
79
from .panel_data import DoubleMLPanelData
10+
from .rdd_data import DoubleMLRDDData
11+
from .ssm_data import DoubleMLSSMData
12+
13+
14+
class DoubleMLClusterData(DoubleMLData):
    """
    Backwards compatibility wrapper for DoubleMLData with cluster_cols.

    This class is deprecated and will be removed with version 0.12.0.
    Use DoubleMLData with cluster_cols instead.

    Parameters
    ----------
    data, y_col, d_cols, cluster_cols, x_cols, z_cols
        Forwarded unchanged (by keyword) to ``DoubleMLData.__init__``.
    t_col, s_col
        Accepted only so that existing call sites keep working. They are
        NOT forwarded to ``DoubleMLData``; a ``UserWarning`` is emitted when
        either one is not ``None``.
    use_other_treat_as_covariate, force_all_x_finite
        Forwarded unchanged to ``DoubleMLData.__init__``.

    Notes
    -----
    ``force_all_d_finite`` is always passed as ``True``; it cannot be
    configured through this wrapper.
    """

    def __init__(
        self,
        data,
        y_col,
        d_cols,
        cluster_cols,
        x_cols=None,
        z_cols=None,
        t_col=None,
        s_col=None,
        use_other_treat_as_covariate=True,
        force_all_x_finite=True,
    ):
        warnings.warn(
            "DoubleMLClusterData is deprecated and will be removed with version 0.12.0. "
            "Use DoubleMLData with cluster_cols instead.",
            FutureWarning,
            stacklevel=2,
        )
        # t_col / s_col are part of the signature but are not handed to the
        # parent constructor below, so tell the caller instead of dropping
        # them silently.
        self._warn_unforwarded_args(t_col=t_col, s_col=s_col)
        super().__init__(
            data=data,
            y_col=y_col,
            d_cols=d_cols,
            x_cols=x_cols,
            z_cols=z_cols,
            cluster_cols=cluster_cols,
            use_other_treat_as_covariate=use_other_treat_as_covariate,
            force_all_x_finite=force_all_x_finite,
            force_all_d_finite=True,
        )

    @staticmethod
    def _warn_unforwarded_args(**kwargs):
        """Emit a ``UserWarning`` naming every keyword whose value is not ``None``."""
        dropped = [name for name, value in kwargs.items() if value is not None]
        if dropped:
            warnings.warn(
                "DoubleMLClusterData does not forward the following argument(s) to DoubleMLData, "
                "they are ignored: " + ", ".join(dropped) + ".",
                UserWarning,
                # 1 = this helper, 2 = __init__ / from_arrays, 3 = the user's call site.
                stacklevel=3,
            )

    @classmethod
    def from_arrays(
        cls, x, y, d, cluster_vars, z=None, t=None, s=None, use_other_treat_as_covariate=True, force_all_x_finite=True
    ):
        """
        Initialize :class:`DoubleMLClusterData` from :class:`numpy.ndarray`'s.
        This method is deprecated and will be removed with version 0.12.0,
        use DoubleMLData.from_arrays with cluster_vars instead.

        Parameters
        ----------
        x, y, d, cluster_vars, z
            Forwarded unchanged (by keyword) to ``DoubleMLData.from_arrays``.
        t, s
            Accepted only so that existing call sites keep working. They are
            NOT forwarded; a ``UserWarning`` is emitted when either one is
            not ``None``.
        use_other_treat_as_covariate, force_all_x_finite
            Forwarded unchanged to ``DoubleMLData.from_arrays``.

        Returns
        -------
        The object returned by ``DoubleMLData.from_arrays``. The call is made
        on ``DoubleMLData`` directly rather than on ``cls``, so the result is
        presumably a plain ``DoubleMLData`` instance -- confirm against
        ``DoubleMLData.from_arrays``.
        """
        warnings.warn(
            "DoubleMLClusterData is deprecated and will be removed with version 0.12.0. "
            "Use DoubleMLData.from_arrays with cluster_vars instead.",
            FutureWarning,
            stacklevel=2,
        )
        # Same reasoning as in __init__: t / s are accepted but not passed on.
        cls._warn_unforwarded_args(t=t, s=s)
        return DoubleMLData.from_arrays(
            x=x,
            y=y,
            d=d,
            z=z,
            cluster_vars=cluster_vars,
            use_other_treat_as_covariate=use_other_treat_as_covariate,
            force_all_x_finite=force_all_x_finite,
            force_all_d_finite=True,
        )
77+
878

9-
__all__ = [
10-
"DoubleMLData",
11-
"DoubleMLClusterData",
12-
"DoubleMLPanelData",
13-
]
79+
# Names exported by ``from doubleml.data import *``. DoubleMLClusterData is
# still listed even though using it emits a FutureWarning (deprecated wrapper).
__all__ = ["DoubleMLData", "DoubleMLClusterData", "DoubleMLDIDData", "DoubleMLPanelData", "DoubleMLRDDData", "DoubleMLSSMData"]

0 commit comments

Comments
 (0)