Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _compute_metrics(
None: The computed metrics are appended to the `results` dictionary in-place.
"""
results["Slice"].append(str(region))
results["Shape"].append(df_region.shape[0])
results["Number of Records"].append(df_region.shape[0])
results["Feature"].append(feature_column)

# Check if df_region is an empty dataframe and if so, append 0 to all metrics
Expand Down Expand Up @@ -222,7 +222,7 @@ def WeakspotsDiagnosis(
thresholds = thresholds or DEFAULT_THRESHOLDS
thresholds = {k.title(): v for k, v in thresholds.items()}

results_headers = ["Slice", "Shape", "Feature"]
results_headers = ["Slice", "Number of Records", "Feature"]
results_headers.extend(metrics.keys())

figures = []
Expand All @@ -236,19 +236,20 @@ def WeakspotsDiagnosis(
feature_columns
+ [datasets[1].target_column, datasets[1].prediction_column(model)]
]

results_1 = pd.DataFrame()
results_2 = pd.DataFrame()
for feature in feature_columns:
bins = 10
if feature in datasets[0].feature_columns_categorical:
bins = len(df_1[feature].unique())
df_1["bin"] = pd.cut(df_1[feature], bins=bins)

results_1 = {k: [] for k in results_headers}
results_2 = {k: [] for k in results_headers}
r1 = {k: [] for k in results_headers}
r2 = {k: [] for k in results_headers}

for region, df_region in df_1.groupby("bin"):
_compute_metrics(
results=results_1,
results=r1,
metrics=metrics,
region=region,
df_region=df_region,
Expand All @@ -260,7 +261,7 @@ def WeakspotsDiagnosis(
(df_2[feature] > region.left) & (df_2[feature] <= region.right)
]
_compute_metrics(
results=results_2,
results=r2,
metrics=metrics,
region=region,
df_region=df_2_region,
Expand All @@ -271,8 +272,8 @@ def WeakspotsDiagnosis(

for metric in metrics.keys():
fig, df = _plot_weak_spots(
results_1=results_1,
results_2=results_2,
results_1=r1,
results_2=r2,
feature_column=feature,
metric=metric,
threshold=thresholds[metric],
Expand All @@ -284,14 +285,18 @@ def WeakspotsDiagnosis(
# rely on visual assessment for this test for now.
if not df[df[list(thresholds.keys())].lt(thresholds).any(axis=1)].empty:
passed = False
results_1 = pd.concat([results_1, pd.DataFrame(r1)])
results_2 = pd.concat([results_2, pd.DataFrame(r2)])

return (
pd.concat(
[
pd.DataFrame(results_1).assign(Dataset=datasets[0].input_id),
pd.DataFrame(results_2).assign(Dataset=datasets[1].input_id),
]
).sort_values(["Feature", "Dataset"]),
)
.reset_index(drop=True)
.sort_values(["Feature", "Dataset"]),
*figures,
passed,
)