Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 57 additions & 65 deletions boruta/boruta_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
from sklearn.utils.validation import check_is_fitted
from sklearn.utils._set_output import _get_output_config
import warnings


Expand Down Expand Up @@ -142,6 +143,10 @@ class BorutaPy(BaseEstimator, SelectorMixin):
The mask of selected tentative features, which haven't gained enough
support during the max_iter number of iterations.

weak : bool, default=False

If set to true, the tentative features are also used to reduce X.

ranking_ : array of shape [n_features]

The feature ranking, such that ``ranking_[i]`` corresponds to the
Expand Down Expand Up @@ -194,7 +199,7 @@ class BorutaPy(BaseEstimator, SelectorMixin):

def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
two_step=True, max_iter=100, random_state=None, verbose=0,
early_stopping=False, n_iter_no_change=20):
early_stopping=False, n_iter_no_change=20, weak: bool = False):
self.estimator = estimator
self.n_estimators = n_estimators
self.perc = perc
Expand All @@ -207,8 +212,9 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
self.n_iter_no_change = n_iter_no_change
self.__version__ = '0.3'
self._is_lightgbm = 'lightgbm' in str(type(self.estimator))
self.weak = weak

def fit(self, X, y):
def fit(self, X, y, **fit_params):
"""
Fits the Boruta feature selection with the provided estimator.

Expand All @@ -223,7 +229,7 @@ def fit(self, X, y):

return self._fit(X, y)

def transform(self, X, weak=False, return_df=False):
def transform(self, X, weak=None, return_df=None):
"""
Reduces the input X to the features selected by Boruta.

Expand All @@ -232,23 +238,49 @@ def transform(self, X, weak=False, return_df=False):
X : array-like, shape = [n_samples, n_features]
The training input samples.

weak: boolean, default = False
If set to true, the tentative features are also used to reduce X.

return_df : boolean, default = False
If ``X`` is a pandas dataframe and this parameter is set to True,
the transformed data will also be a dataframe.
weak : boolean, optional
Deprecated. Set ``weak`` in the constructor instead.

Returns
-------
X : array-like, shape = [n_samples, n_features_]
The input matrix X's columns are reduced to the features which were
selected by Boruta.
return_df : boolean, optional
Deprecated. Output type now follows scikit-learn's standard
``set_output``/``set_config`` mechanism.
"""

return self._transform(X, weak, return_df)

def fit_transform(self, X, y, weak=False, return_df=False):
prev_weak = self.weak
if weak is not None:
warnings.warn(
"`weak` is deprecated and will be removed in a future release. "
"Set `weak` in the constructor instead.",
FutureWarning,
stacklevel=2,
)
self.weak = weak
requested_transform = None
prev_output_config = None
force_numpy = return_df is False
if return_df is not None:
warnings.warn(
"`return_df` is deprecated and will be removed in a future "
"release. Use scikit-learn's `set_output(transform='pandas')` "
"or `set_config(transform_output='pandas')` instead.",
FutureWarning,
stacklevel=2,
)
prev_output_config = _get_output_config("transform", estimator=self)["dense"]
requested_transform = "pandas" if return_df else "default"
if prev_output_config != requested_transform:
self.set_output(transform=requested_transform)
try:
result = super().transform(X)
finally:
if weak is not None:
self.weak = prev_weak
if requested_transform is not None and prev_output_config != requested_transform:
self.set_output(transform=prev_output_config)
if force_numpy and hasattr(result, "to_numpy"):
result = result.to_numpy()
return result

def fit_transform(self, X, y=None, **fit_params):
"""
Fits Boruta, then reduces the input X to the selected features.

Expand All @@ -259,31 +291,10 @@ def fit_transform(self, X, y, weak=False, return_df=False):

y : array-like, shape = [n_samples]
The target values.

weak: boolean, default = False
If set to true, the tentative features are also used to reduce X.

return_df : boolean, default = False
If ``X`` is a pandas dataframe and this parameter is set to True,
the transformed data will also be a dataframe.

Returns
-------
X : array-like, shape = [n_samples, n_features_]
The input matrix X's columns are reduced to the features which were
selected by Boruta.
"""

self._fit(X, y)
return self._transform(X, weak, return_df)

def _validate_pandas_input(self, arg):
    """
    Return the data held in ``arg.values``.

    Parameters
    ----------
    arg : object
        Object exposing its data through a non-callable ``values``
        attribute (presumably a pandas DataFrame or Series, going by the
        error message).

    Returns
    -------
    values
        Whatever ``arg.values`` holds.

    Raises
    ------
    ValueError
        If ``arg`` has no ``values`` attribute, or if ``values`` is a
        method (as on a ``dict``) rather than data.
    """
    values = getattr(arg, "values", None)
    # A callable ``values`` (e.g. ``dict.values``) is not array data;
    # handing it back would only move the failure somewhere less obvious.
    if values is None or callable(values):
        raise ValueError(
            "input needs to be a numpy array or pandas data frame."
        )
    return values
weak = fit_params.pop("weak", None)
return_df = fit_params.pop("return_df", None)
return self.fit(X, y, **fit_params).transform(X, weak=weak, return_df=return_df)

def _fit(self, X, y):
# check input params
Expand All @@ -295,10 +306,7 @@ def _fit(self, X, y):
else:
self.feature_names_in_ = None

if not isinstance(X, np.ndarray):
X = self._validate_pandas_input(X)
if not isinstance(y, np.ndarray):
y = self._validate_pandas_input(y)
X, y = check_X_y(X, y, accept_sparse=False, ensure_2d=True, dtype=None, estimator=self)

self.n_features_in_ = X.shape[1]

Expand Down Expand Up @@ -446,24 +454,6 @@ def _fit(self, X, y):
self._print_results(dec_reg, _iter, 1)
return self

def _transform(self, X, weak=False, return_df=False):
    """
    Reduce ``X`` to the columns flagged in the support mask(s).

    Parameters
    ----------
    X : array-like or pandas DataFrame, shape = [n_samples, n_features]
        The samples to reduce.

    weak : boolean, default = False
        If true, columns flagged in ``support_weak_`` are kept in addition
        to those flagged in ``support_``.

    return_df : boolean, default = False
        If ``X`` is a pandas DataFrame and this is true, a DataFrame is
        returned. Otherwise the result is an array.

    Returns
    -------
    X : array-like or DataFrame
        ``X`` restricted to the selected columns.

    Raises
    ------
    ValueError
        If the ``ranking_`` attribute is missing.
    """
    # sanity check (presumably ``ranking_`` is only set by fitting)
    if not hasattr(self, 'ranking_'):
        raise ValueError('You need to call the fit(X, y) method first.')

    if weak:
        # Explicit element-wise OR instead of relying on boolean ``+``.
        indices = np.logical_or(self.support_, self.support_weak_)
    else:
        indices = self.support_

    if hasattr(X, 'iloc'):
        # DataFrames need positional indexing; ``X[:, mask]`` fails on them.
        selected = X.iloc[:, indices]
        return selected if return_df else selected.to_numpy()
    # Non-DataFrame input: ``return_df`` has nothing to preserve.
    return X[:, indices]

def _set_n_estimators(self, n_estimators):
try:
self.estimator.set_params(n_estimators=n_estimators)
Expand All @@ -476,7 +466,9 @@ def _set_n_estimators(self, n_estimators):
return self

def _get_support_mask(self):
    """
    Return the boolean mask of features to keep.

    Returns
    -------
    mask : boolean array
        ``support_`` alone, or the element-wise OR of ``support_`` and
        ``support_weak_`` when ``self.weak`` is truthy.
    """
    # One call covers both attributes; a separate check for ``support_``
    # alone would be redundant.
    check_is_fitted(self, ['support_', 'support_weak_'])
    if self.weak:
        return np.logical_or(self.support_, self.support_weak_)
    return self.support_

def _get_tree_num(self, n_feat):
Expand Down
79 changes: 77 additions & 2 deletions boruta/test/test_boruta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import re

import numpy as np
import pandas as pd
import pytest
from sklearn import config_context
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
Expand Down Expand Up @@ -65,8 +68,80 @@ def test_dataframe_is_returned(Xy):
X_df, y_df = pd.DataFrame(X), pd.Series(y)
rfc = RandomForestClassifier()
bt = BorutaPy(rfc)
bt.fit(X_df, y_df)
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)
with config_context(transform_output="pandas"):
bt.fit(X_df, y_df)
transformed = bt.transform(X_df)
assert isinstance(transformed, pd.DataFrame)


def test_return_df_parameter_emits_warning(Xy):
    """The deprecated ``return_df=True`` warns and still yields a DataFrame."""
    features, target = Xy
    frame = pd.DataFrame(features)
    labels = pd.Series(target)
    selector = BorutaPy(RandomForestClassifier())
    expected_hint = re.escape("`set_output(transform='pandas')`")
    # NOTE(review): indentation was lost in SOURCE; the warning check is
    # assumed to run inside the pandas config context -- confirm.
    with config_context(transform_output="pandas"):
        selector.fit(frame, labels)
        with pytest.warns(FutureWarning, match=expected_hint):
            result = selector.transform(frame, return_df=True)
    assert isinstance(result, pd.DataFrame)


def test_return_df_true_temporarily_enables_pandas_output(Xy):
    """``return_df=True`` affects only the call that passes it."""
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier())
    selector.fit(features, target)

    # Before: a plain call returns an array.
    before = selector.transform(features)
    assert isinstance(before, np.ndarray)

    # During: the deprecated flag warns and returns a DataFrame.
    with pytest.warns(FutureWarning, match="`return_df` is deprecated"):
        as_frame = selector.transform(features, return_df=True)
    assert isinstance(as_frame, pd.DataFrame)

    # After: the next plain call returns an array again.
    after = selector.transform(features)
    assert isinstance(after, np.ndarray)


def test_return_df_false_with_dataframe_input_returns_numpy(Xy):
    """An explicit ``return_df=False`` yields an array for DataFrame input."""
    features, target = Xy
    frame = pd.DataFrame(features)
    selector = BorutaPy(RandomForestClassifier())
    selector.fit(frame, target)

    with pytest.warns(FutureWarning, match="`return_df` is deprecated"):
        result = selector.transform(frame, return_df=False)
    assert isinstance(result, np.ndarray)


def test_weak_attribute_controls_support_mask(Xy):
    """With ``weak=True``, ``get_support`` is the union of both masks."""
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier(), weak=True)
    selector.fit(features, target)

    expected_mask = selector.support_ | selector.support_weak_
    actual_mask = selector.get_support()
    assert np.array_equal(actual_mask, expected_mask)


def test_transform_with_weak_parameter_is_deprecated(Xy):
    """``transform(weak=True)`` warns and keeps the tentative columns too."""
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier())
    selector.fit(features, target)
    # Flag column 5 in the weak mask only, so the union of the two masks
    # is guaranteed to include a column that ``support_`` alone does not.
    selector.support_[5] = False
    selector.support_weak_[5] = True

    warning_pattern = re.escape("`weak` is deprecated")
    with pytest.warns(FutureWarning, match=warning_pattern):
        reduced = selector.transform(features, weak=True)

    union_mask = selector.support_ | selector.support_weak_
    assert reduced.shape[1] == np.count_nonzero(union_mask)


def test_fit_transform_with_weak_parameter_is_deprecated(Xy):
    """``fit_transform(weak=True)`` warns and keeps the tentative columns."""
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier())
    warning_pattern = re.escape("`weak` is deprecated")
    with pytest.warns(FutureWarning, match=warning_pattern):
        reduced = selector.fit_transform(features, target, weak=True)
    union_mask = selector.support_ | selector.support_weak_
    assert reduced.shape[1] == np.count_nonzero(union_mask)


def test_selector_mixin_get_support_requires_fit():
Expand Down