Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ description = "ValidMind Library"
license = "Commercial License"
name = "validmind"
readme = "README.pypi.md"
version = "2.8.14"
version = "2.8.15"

[tool.poetry.dependencies]
aiohttp = {extras = ["speedups"], version = "*"}
Expand Down
2 changes: 1 addition & 1 deletion validmind/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.8.14"
__version__ = "2.8.15"
38 changes: 17 additions & 21 deletions validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,25 @@ def select_shap_values(
"""
if not isinstance(shap_values, list):
# For regression, return the SHAP values as they are
# TODO: shap_values is always an array of all predictions, how is the if above supposed to work?
# logger.info("Returning SHAP values as-is.")
return shap_values

num_classes = len(shap_values)

# Default to class 1 for binary classification where no class is specified
if num_classes == 2 and class_of_interest is None:
logger.debug("Using SHAP values for class 1 (positive class).")
return shap_values[1]
selected_values = shap_values
else:
num_classes = len(shap_values)
# Default to class 1 for binary classification where no class is specified
if num_classes == 2 and class_of_interest is None:
selected_values = shap_values[1]
# Otherwise, use the specified class_of_interest
elif class_of_interest is not None and 0 <= class_of_interest < num_classes:
selected_values = shap_values[class_of_interest]
else:
raise ValueError(
f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
)

# Otherwise, use the specified class_of_interest
if (
class_of_interest is None
or class_of_interest < 0
or class_of_interest >= num_classes
):
raise ValueError(
f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
)
# Add type conversion here to ensure proper float array
if hasattr(selected_values, "dtype"):
selected_values = np.array(selected_values, dtype=np.float64)

logger.debug(f"Using SHAP values for class {class_of_interest}.")
return shap_values[class_of_interest]
return selected_values


def generate_shap_plot(
Expand Down