6 changes: 3 additions & 3 deletions econml/dml/_rlearner.py
@@ -50,12 +50,12 @@ def __init__(self, model_y: ModelSelector, model_t: ModelSelector):
         self._model_y = model_y
         self._model_t = model_t

-    def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
+    def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
         assert Z is None, "Cannot accept instrument!"
         self._model_t.train(is_selecting, folds, X, W, T, **
-                            filter_none_kwargs(sample_weight=sample_weight, groups=groups))
+                            filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
         self._model_y.train(is_selecting, folds, X, W, Y, **
-                            filter_none_kwargs(sample_weight=sample_weight, groups=groups))
+                            filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
         return self

     def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None,
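Every hunk in this PR threads the new **fit_params through the same helper. As a point of reference, a minimal stand-in for econml's filter_none_kwargs utility (hypothetical reimplementation; the real helper lives in econml.utilities and may differ in detail) behaves like this:

# Hypothetical stand-in for econml.utilities.filter_none_kwargs:
# drop keyword arguments whose value is None so that unset optional
# arguments (e.g. sample_weight) are omitted instead of being
# forwarded as None to the underlying model's train/fit call.
def filter_none_kwargs(**kwargs):
    return {name: value for name, value in kwargs.items() if value is not None}

# Caller-supplied fit_params pass straight through; None-valued
# optional arguments disappear.
print(filter_none_kwargs(sample_weight=None, groups=None, verbose=1))
# {'verbose': 1}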
4 changes: 2 additions & 2 deletions econml/dml/dml.py
@@ -97,7 +97,7 @@ def __init__(self, model: SingleModelSelector, discrete_target):
         self._model = clone(model, safe=False)
         self._discrete_target = discrete_target

-    def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None):
+    def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None, **fit_params):
         if self._discrete_target:
             # In this case, the Target is the one-hot-encoding of the treatment variable
             # We need to go back to the label representation of the one-hot so as to call
@@ -108,7 +108,7 @@ def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None):
             Target = inverse_onehot(Target)

         self._model.train(is_selecting, folds, _combine(X, W, Target.shape[0]), Target,
-                          **filter_none_kwargs(groups=groups, sample_weight=sample_weight))
+                          **filter_none_kwargs(groups=groups, sample_weight=sample_weight, **fit_params))
         return self

     @property
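Both hunks above sit around the inverse_onehot call that recovers integer treatment labels before fitting. A hedged sketch of that helper (hypothetical reimplementation, assuming the one-hot matrix drops the baseline category as econml's encoding does; the real function is in econml.utilities):

import numpy as np

def inverse_onehot(T):
    # Hypothetical sketch: T is an (n, d) one-hot matrix whose columns
    # encode treatments 1..d, with an all-zero row meaning the baseline
    # treatment 0. Return integer labels of shape (n,).
    T = np.asarray(T)
    return (T @ np.arange(1, T.shape[1] + 1)).astype(int)

# Example: rows [0, 0], [1, 0], [0, 1] map to labels 0, 1, 2.
print(inverse_onehot(np.array([[0, 0], [1, 0], [0, 1]])))
# [0 1 2]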
4 changes: 2 additions & 2 deletions econml/dr/_drlearner.py
@@ -70,7 +70,7 @@ def __init__(self,
     def _combine(self, X, W):
         return np.hstack([arr for arr in [X, W] if arr is not None])

-    def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
+    def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None, **fit_params):
         if Y.ndim != 1 and (Y.ndim != 2 or Y.shape[1] != 1):
             raise ValueError("The outcome matrix must be of shape ({0}, ) or ({0}, 1), "
                              "instead got {1}.".format(len(X), Y.shape))
@@ -80,7 +80,7 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
             raise AttributeError("Provided crossfit folds contain training splits that " +
                                  "don't contain all treatments")
         XW = self._combine(X, W)
-        filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight)
+        filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, **fit_params)

         self._model_propensity.train(is_selecting, folds, XW, inverse_onehot(T), groups=groups, **filtered_kwargs)
         self._model_regression.train(is_selecting, folds, np.hstack([XW, T]), Y, groups=groups, **filtered_kwargs)
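In the doubly robust learner the same filtered dictionary is reused, so any caller-supplied fit parameter reaches both the propensity model and the regression model. A stripped-down illustration of that passthrough, with hypothetical names (NuisanceTrainer, inner models exposing a plain fit(X, y, **kwargs)) standing in for the actual econml classes:

class NuisanceTrainer:
    """Hypothetical illustration of the **fit_params passthrough."""

    def __init__(self, model_propensity, model_regression):
        self._model_propensity = model_propensity
        self._model_regression = model_regression

    def train(self, XW, T_labels, Y, sample_weight=None, **fit_params):
        # Build one filtered kwargs dict and hand it to both first-stage
        # models, mirroring the change in _drlearner.py.
        kwargs = dict(sample_weight=sample_weight, **fit_params)
        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        self._model_propensity.fit(XW, T_labels, **filtered_kwargs)
        self._model_regression.fit(XW, Y, **filtered_kwargs)
        return self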