Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def __init__(self, model_final):
def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
sample_weight=None, freq_weight=None, sample_var=None, groups=None):
Y_res, T_res = nuisances
self._model_final.fit(X, T, T_res, Y_res, sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var)
self._model_final.fit(X, T, T_res, Y_res, **(filter_none_kwargs(sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var, groups=groups)))
return self

def predict(self, X=None):
Expand Down
8 changes: 5 additions & 3 deletions econml/iv/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
XT_res = self._combine(X, T_res)
XZ_res = self._combine(X, Z_res)
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var)
freq_weight=freq_weight, sample_var=sample_var, groups=groups)

self._model_final.fit(XZ_res, XT_res, Y_res, **filtered_kwargs)

Expand Down Expand Up @@ -376,14 +376,16 @@ def __init__(self, *,
mc_iters=None,
mc_agg='mean',
random_state=None,
allow_missing=False):
allow_missing=False,
cov_type="HC0"):
self.model_y_xw = clone(model_y_xw, safe=False)
self.model_t_xw = clone(model_t_xw, safe=False)
self.model_t_xwz = clone(model_t_xwz, safe=False)
self.model_z_xw = clone(model_z_xw, safe=False)
self.projection = projection
self.featurizer = clone(featurizer, safe=False)
self.fit_cate_intercept = fit_cate_intercept
self.cov_type = cov_type

super().__init__(discrete_outcome=discrete_outcome,
discrete_instrument=discrete_instrument,
Expand All @@ -403,7 +405,7 @@ def _gen_featurizer(self):
return clone(self.featurizer, safe=False)

def _gen_model_final(self):
return StatsModels2SLS(cov_type="HC0")
return StatsModels2SLS(cov_type=self.cov_type)

def _gen_ortho_learner_model_final(self):
    """Build the final-stage wrapper from a fresh final model, a fresh featurizer and the intercept flag."""
    final_model = self._gen_model_final()
    featurizer = self._gen_featurizer()
    return _OrthoIVModelFinal(final_model, featurizer, self.fit_cate_intercept)
Expand Down
149 changes: 130 additions & 19 deletions econml/sklearn_extensions/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,7 +1693,7 @@ class StatsModelsLinearRegression(_StatsModelsWrapper):
fit_intercept : bool, default True
Whether to fit an intercept in this model
cov_type : string, default "HC0"
The covariance approach to use. Supported values are "HCO", "HC1", and "nonrobust".
The covariance approach to use. Supported values are "HC0", "HC1", "nonrobust", and "clustered".
enable_federation : bool, default False
Whether to enable federation (aggregating this model's results with other models in a distributed setting).
This requires additional memory proportional to the number of columns in X to the fourth power.
Expand All @@ -1704,10 +1704,10 @@ def __init__(self, fit_intercept=True, cov_type="HC0", *, enable_federation=Fals
self.fit_intercept = fit_intercept
self.enable_federation = enable_federation

def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
def _check_input(self, X, y, sample_weight, freq_weight, sample_var, groups=None):
"""Check dimensions and other assertions."""
X, y, sample_weight, freq_weight, sample_var = check_input_arrays(
X, y, sample_weight, freq_weight, sample_var, dtype='numeric')
X, y, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
X, y, sample_weight, freq_weight, sample_var, groups, dtype='numeric')
if X is None:
X = np.empty((y.shape[0], 0))
if self.fit_intercept:
Expand All @@ -1720,6 +1720,8 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
freq_weight = np.ones(y.shape[0])
if sample_var is None:
sample_var = np.zeros(y.shape)
if groups is None:
groups = np.arange(y.shape[0])

# check freq_weight should be integer and should be accompanied by sample_var
if np.any(np.not_equal(np.mod(freq_weight, 1), 0)):
Expand Down Expand Up @@ -1753,7 +1755,7 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):

# check array shape
assert (X.shape[0] == y.shape[0] == sample_weight.shape[0] ==
freq_weight.shape[0] == sample_var.shape[0]), "Input lengths not compatible!"
freq_weight.shape[0] == sample_var.shape[0] == groups.shape[0]), "Input lengths not compatible!"
if y.ndim >= 2:
assert (y.ndim == sample_var.ndim and
y.shape[1] == sample_var.shape[1]), "Input shapes not compatible: {}, {}!".format(
Expand All @@ -1767,9 +1769,9 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
else:
weighted_y = y * np.sqrt(sample_weight).reshape(-1, 1)
sample_var = sample_var * (sample_weight.reshape(-1, 1))
return weighted_X, weighted_y, freq_weight, sample_var
return weighted_X, weighted_y, freq_weight, sample_var, groups

def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
"""
Fits the model.

Expand All @@ -1788,13 +1790,15 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
sample_var : {(N,), (N, p)} nd array_like or None
Variance of the outcome(s) of the original freq_weight[i] observations that were used to
compute the mean outcome represented by observation i.
groups : (N,) array_like or None
Group labels for clustered standard errors.

Returns
-------
self : StatsModelsLinearRegression
"""
# TODO: Add other types of covariance estimation (e.g. Newey-West (HAC), HC2, HC3)
X, y, freq_weight, sample_var = self._check_input(X, y, sample_weight, freq_weight, sample_var)
X, y, freq_weight, sample_var, groups = self._check_input(X, y, sample_weight, freq_weight, sample_var, groups)

WX = X * np.sqrt(freq_weight).reshape(-1, 1)

Expand Down Expand Up @@ -1840,6 +1844,8 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
self.XXXy = np.einsum('nx,ny->yx', WX, wy)
self.XXXX = np.einsum('nw,nx->wx', WX, WX)
self.sample_var = np.average(sv, weights=freq_weight, axis=0) * n_obs
elif self.cov_type == 'clustered':
raise AttributeError("Clustered standard errors are not supported with federation enabled.")

sigma_inv = np.linalg.pinv(self.XX)

Expand Down Expand Up @@ -1871,8 +1877,10 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
for j in range(self._n_out):
weighted_sigma = np.matmul(WX.T, WX * var_i[:, [j]])
self._var.append(correction * np.matmul(sigma_inv, np.matmul(weighted_sigma, sigma_inv)))
elif (self.cov_type == 'clustered'):
self._var = self._compute_clustered_variance_linear(WX, y - np.matmul(X, param), sigma_inv, groups)
else:
raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1.")
raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered.")

self._param_var = np.array(self._var)

Expand Down Expand Up @@ -1937,7 +1945,6 @@ def aggregate(models: List[StatsModelsLinearRegression]):
agg_model._var = correction * np.matmul(sigma_inv, np.matmul(weighted_sigma.squeeze(0), sigma_inv))
else:
agg_model._var = [correction * np.matmul(sigma_inv, np.matmul(ws, sigma_inv)) for ws in weighted_sigma]

else:
assert agg_model.cov_type == 'nonrobust' or agg_model.cov_type is None
sigma = XXyy - 2 * np.einsum('yx,xy->y', XXXy, param) + np.einsum('wx,wy,xy->y', XXXX, param, param)
Expand All @@ -1954,6 +1961,54 @@ def aggregate(models: List[StatsModelsLinearRegression]):

return agg_model

def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
"""
Compute clustered standard errors for linear regression.

Parameters
----------
WX : array_like
Weighted design matrix
eps_i : array_like
Residuals
sigma_inv : array_like
Inverse of X.T @ X
groups : array_like
Group labels for clustering

Returns
-------
var : array_like or list
Clustered variance matrix
"""
n, k = WX.shape
group_ids, inverse_idx = np.unique(groups, return_inverse=True)
n_groups = len(group_ids)

# Group correction factor
group_correction = (n_groups / (n_groups - 1))

if eps_i.ndim < 2:
# Single outcome case
WX_e = WX * eps_i.reshape(-1, 1)
group_sums = np.zeros((n_groups, k))
np.add.at(group_sums, inverse_idx, WX_e)
s = group_sums.T @ group_sums

return group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv))
else:
# Multiple outcome case
var_list = []
for j in range(eps_i.shape[1]):
WX_e = WX * eps_i[:, [j]]
group_sums = np.zeros((n_groups, k))
np.add.at(group_sums, inverse_idx, WX_e)
s = group_sums.T @ group_sums

var_list.append(group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv)))

return var_list


class StatsModelsRLM(_StatsModelsWrapper):
"""
Expand Down Expand Up @@ -2040,23 +2095,28 @@ class StatsModels2SLS(_StatsModelsWrapper):

Parameters
----------
cov_type : {'HC0', 'HC1', 'nonrobust', or None}, default 'HC0'
Indicates how the covariance matrix is estimated.
cov_type : {'HC0', 'HC1', 'nonrobust', 'clustered', or None}, default 'HC0'
Indicates how the covariance matrix is estimated. 'clustered' requires groups to be provided in fit().
"""

def __init__(self, cov_type="HC0"):
    """Record the covariance estimator choice; this model is always fit without an intercept."""
    self.cov_type = cov_type
    self.fit_intercept = False

def _check_input(self, Z, T, y, sample_weight):
def _check_input(self, Z, T, y, sample_weight, groups=None):
"""Check dimensions and other assertions."""
# set default values for None
if sample_weight is None:
sample_weight = np.ones(y.shape[0])
if groups is None:
groups = np.arange(y.shape[0])
else:
groups = np.asarray(groups)

# check array shape
assert (T.shape[0] == Z.shape[0] == y.shape[0] == sample_weight.shape[0]), "Input lengths not compatible!"
assert (T.shape[0] == Z.shape[0] == y.shape[0] == sample_weight.shape[0] == groups.shape[0]), \
"Input lengths not compatible!"

# check dimension of instruments is more than dimension of treatments
if Z.shape[1] < T.shape[1]:
Expand All @@ -2073,9 +2133,9 @@ def _check_input(self, Z, T, y, sample_weight):
weighted_y = y * np.sqrt(sample_weight)
else:
weighted_y = y * np.sqrt(sample_weight).reshape(-1, 1)
return weighted_Z, weighted_T, weighted_y
return weighted_Z, weighted_T, weighted_y, groups

def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
"""
Fits the model.

Expand All @@ -2096,7 +2156,8 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
sample_var : {(N,), (N, p)} nd array_like or None
Variance of the outcome(s) of the original freq_weight[i] observations that were used to
compute the mean outcome represented by observation i.

groups : (N,) array_like or None
Group labels for clustered standard errors. Required when cov_type='clustered'.

Returns
-------
Expand All @@ -2105,7 +2166,7 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
assert freq_weight is None, "freq_weight is not supported yet for this class!"
assert sample_var is None, "sample_var is not supported yet for this class!"

Z, T, y = self._check_input(Z, T, y, sample_weight)
Z, T, y, groups = self._check_input(Z, T, y, sample_weight, groups)

self._n_out = 0 if y.ndim < 2 else y.shape[1]

Expand Down Expand Up @@ -2164,8 +2225,58 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
weighted_sigma = np.matmul(that.T, that * var_i[:, [j]])
self._var.append(correction * np.matmul(thatT_that_inv,
np.matmul(weighted_sigma, thatT_that_inv)))
elif (self.cov_type == 'clustered'):
self._var = self._compute_clustered_variance(that, y - np.dot(T, param), thatT_that_inv, groups)
else:
raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1.")
raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered.")

self._param_var = np.array(self._var)
return self

def _compute_clustered_variance(self, that, eps_i, thatT_that_inv, groups):
"""
Compute clustered standard errors.

Parameters
----------
that : array_like
Fitted values from first stage
eps_i : array_like
Residuals
thatT_that_inv : array_like
Inverse of that.T @ that
groups : array_like
Group labels for clustering

Returns
-------
var : array_like or list
Clustered variance matrix
"""
n, k = that.shape
group_ids, inverse_idx = np.unique(groups, return_inverse=True)
n_groups = len(group_ids)

# Group correction factor
group_correction = (n_groups / (n_groups - 1))

if eps_i.ndim < 2:
# Single outcome case
that_e = that * eps_i.reshape(-1, 1)
group_sums = np.zeros((n_groups, k))
np.add.at(group_sums, inverse_idx, that_e)
s = group_sums.T @ group_sums

return group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv))
else:
# Multiple outcome case
var_list = []
for j in range(eps_i.shape[1]):
that_e = that * eps_i[:, [j]]
group_sums = np.zeros((n_groups, k))
np.add.at(group_sums, inverse_idx, that_e)
s = group_sums.T @ group_sums

var_list.append(group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv)))

return var_list
Loading
Loading