Skip to content

Commit 3071a4f

Browse files
feat: new coherence metrics (#118)
1 parent 1f2ff24 commit 3071a4f

File tree

10 files changed

+992
-173
lines changed

10 files changed

+992
-173
lines changed

mostlyai/qa/_accuracy.py

Lines changed: 99 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import hashlib
1717
import logging
1818
import math
19-
from typing import Any
19+
from typing import Any, Literal
2020
from collections.abc import Callable, Iterable
2121

2222
import fastcluster
@@ -196,22 +196,27 @@ def calculate_accuracy(trn_bin_cols: pd.DataFrame, syn_bin_cols: pd.DataFrame) -
196196
that can be expected due to the sampling noise.
197197
"""
198198

199-
# create relative frequency tables for `trn` and `syn`
200-
trn_freq = trn_bin_cols.value_counts(normalize=True)
201-
syn_freq = syn_bin_cols.value_counts(normalize=True)
199+
trn_bin_cnts = trn_bin_cols.value_counts()
200+
syn_bin_cnts = syn_bin_cols.value_counts()
201+
return calculate_accuracy_cnts(trn_bin_cnts, syn_bin_cnts)
202+
203+
204+
def calculate_accuracy_cnts(trn_bin_cnts: pd.Series, syn_bin_cnts: pd.Series) -> tuple[np.float64, np.float64]:
205+
n_trn = trn_bin_cnts.sum()
206+
n_syn = syn_bin_cnts.sum()
207+
trn_freq = trn_bin_cnts / n_trn
208+
syn_freq = syn_bin_cnts / n_syn
202209
freq = pd.merge(
203210
trn_freq.to_frame("tgt").reset_index(),
204211
syn_freq.to_frame("syn").reset_index(),
205212
how="outer",
206-
on=list(trn_bin_cols.columns),
213+
on=list(trn_bin_cnts.index.names),
207214
)
208215
freq["tgt"] = freq["tgt"].fillna(0.0)
209216
freq["syn"] = freq["syn"].fillna(0.0)
210217
# calculate L1 distance between `trn` and `syn`
211218
observed_l1 = (freq["tgt"] - freq["syn"]).abs().sum()
212219
# calculated expected L1 distance based on `trn`
213-
n_trn = trn_bin_cols.shape[0]
214-
n_syn = syn_bin_cols.shape[0]
215220
expected_l1 = calculate_expected_l1_multinomial(freq["tgt"].to_list(), n_trn, n_syn)
216221
# convert to accuracy; trim superfluous precision
217222
observed_acc = (1 - observed_l1 / 2).round(5)
@@ -413,14 +418,14 @@ def plot_store_univariate(
413418
workspace: TemporaryWorkspace,
414419
) -> None:
415420
fig = plot_univariate(
416-
col,
417-
trn_num_kde,
418-
syn_num_kde,
419-
trn_cat_col_cnts,
420-
syn_cat_col_cnts,
421-
trn_bin_col_cnts,
422-
syn_bin_col_cnts,
423-
accuracy,
421+
col_name=col,
422+
trn_num_kde=trn_num_kde,
423+
syn_num_kde=syn_num_kde,
424+
trn_cat_col_cnts=trn_cat_col_cnts,
425+
syn_cat_col_cnts=syn_cat_col_cnts,
426+
trn_bin_col_cnts=trn_bin_col_cnts,
427+
syn_bin_col_cnts=syn_bin_col_cnts,
428+
accuracy=accuracy,
424429
)
425430
workspace.store_figure_html(fig, "univariate", col)
426431

@@ -433,7 +438,11 @@ def plot_univariate(
433438
syn_cat_col_cnts: pd.Series | None,
434439
trn_bin_col_cnts: pd.Series,
435440
syn_bin_col_cnts: pd.Series,
436-
accuracy: float | None,
441+
trn_cnt: int | None = None,
442+
syn_cnt: int | None = None,
443+
accuracy: float | None = None,
444+
sort_categorical_binned_by_frequency: bool = True,
445+
max_label_length: int = 10,
437446
) -> go.Figure:
438447
# either numerical/datetime KDEs or categorical counts must be provided
439448

@@ -480,13 +489,27 @@ def plot_univariate(
480489
is_numeric = trn_num_kde is not None
481490
if is_numeric:
482491
trn_line1, syn_line1 = plot_univariate_distribution_numeric(trn_num_kde, syn_num_kde)
483-
trn_line2, syn_line2 = plot_univariate_binned(trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency=False)
492+
trn_line2, syn_line2 = plot_univariate_binned(
493+
trn_bin_col_cnts,
494+
syn_bin_col_cnts,
495+
sort_by_frequency=False,
496+
trn_cnt=trn_cnt,
497+
syn_cnt=syn_cnt,
498+
)
484499
# prevent Plotly from trying to convert strings to dates
485500
fig.layout.xaxis2.update(type="category")
486501
else:
487502
fig.layout.yaxis.update(tickformat=".0%")
488-
trn_line1, syn_line1 = plot_univariate_distribution_categorical(trn_cat_col_cnts, syn_cat_col_cnts)
489-
trn_line2, syn_line2 = plot_univariate_binned(trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency=True)
503+
trn_line1, syn_line1 = plot_univariate_distribution_categorical(
504+
trn_cat_col_cnts, syn_cat_col_cnts, trn_cnt, syn_cnt, max_label_length=max_label_length
505+
)
506+
trn_line2, syn_line2 = plot_univariate_binned(
507+
trn_bin_col_cnts,
508+
syn_bin_col_cnts,
509+
sort_by_frequency=sort_categorical_binned_by_frequency,
510+
trn_cnt=trn_cnt,
511+
syn_cnt=syn_cnt,
512+
)
490513
# prevent Plotly from trying to convert strings to dates
491514
fig.layout.xaxis.update(type="category")
492515
fig.layout.xaxis2.update(type="category")
@@ -505,62 +528,71 @@ def plot_univariate(
505528
def prepare_categorical_plot_data_distribution(
506529
trn_col_cnts: pd.Series,
507530
syn_col_cnts: pd.Series,
531+
trn_cnt: int | None = None,
532+
syn_cnt: int | None = None,
508533
) -> pd.DataFrame:
509534
trn_col_cnts_idx = trn_col_cnts.index.to_series().astype("string").fillna(NA_BIN).replace("", EMPTY_BIN)
510535
syn_col_cnts_idx = syn_col_cnts.index.to_series().astype("string").fillna(NA_BIN).replace("", EMPTY_BIN)
511536
trn_col_cnts = trn_col_cnts.set_axis(trn_col_cnts_idx)
512537
syn_col_cnts = syn_col_cnts.set_axis(syn_col_cnts_idx)
513-
t = trn_col_cnts.to_frame("target_cnt").reset_index()
514-
s = syn_col_cnts.to_frame("synthetic_cnt").reset_index()
515-
df = pd.merge(t, s, on="index", how="outer")
538+
t = trn_col_cnts.to_frame("target_cnt").reset_index(names="category")
539+
s = syn_col_cnts.to_frame("synthetic_cnt").reset_index(names="category")
540+
df = pd.merge(t, s, on="category", how="outer")
516541
df["target_cnt"] = df["target_cnt"].fillna(0.0)
517542
df["synthetic_cnt"] = df["synthetic_cnt"].fillna(0.0)
518543
df["avg_cnt"] = (df["target_cnt"] + df["synthetic_cnt"]) / 2
519544
df = df[df["avg_cnt"] > 0]
520-
df["target_pct"] = df["target_cnt"] / df["target_cnt"].sum()
521-
df["synthetic_pct"] = df["synthetic_cnt"] / df["synthetic_cnt"].sum()
522-
df = df.rename(columns={"index": "category"})
523-
if df["category"].dtype.name == "category":
524-
df["category_code"] = df["category"].cat.codes
525-
else:
526-
df["category_code"] = df["category"]
527-
df = df.sort_values("category_code", ascending=True).reset_index(drop=True)
545+
trn_cnt = trn_cnt or df["target_cnt"].sum()
546+
syn_cnt = syn_cnt or df["synthetic_cnt"].sum()
547+
df["target_pct"] = df["target_cnt"] / trn_cnt
548+
df["synthetic_pct"] = df["synthetic_cnt"] / syn_cnt
549+
df = df.sort_values("avg_cnt", ascending=False).reset_index(drop=True)
528550
return df
529551

530552

531553
def prepare_categorical_plot_data_binned(
532554
trn_bin_col_cnts: pd.Series,
533555
syn_bin_col_cnts: pd.Series,
534556
sort_by_frequency: bool,
557+
trn_cnt: int | None = None,
558+
syn_cnt: int | None = None,
535559
) -> pd.DataFrame:
536560
t = trn_bin_col_cnts.to_frame("target_cnt").reset_index(names="category")
537561
s = syn_bin_col_cnts.to_frame("synthetic_cnt").reset_index(names="category")
538-
df = pd.merge(t, s, on="category", how="outer")
562+
df = pd.merge(t, s, on="category", how="left")
563+
df = df.set_index("category").reindex(t["category"]).reset_index()
564+
missing_s = s[~s["category"].isin(t["category"])]
565+
if not missing_s.empty:
566+
df = pd.concat([df, missing_s], ignore_index=True)
539567
df["target_cnt"] = df["target_cnt"].fillna(0.0)
540568
df["synthetic_cnt"] = df["synthetic_cnt"].fillna(0.0)
541569
df["avg_cnt"] = (df["target_cnt"] + df["synthetic_cnt"]) / 2
542570
df = df[df["avg_cnt"] > 0]
543-
df["target_pct"] = df["target_cnt"] / df["target_cnt"].sum()
544-
df["synthetic_pct"] = df["synthetic_cnt"] / df["synthetic_cnt"].sum()
545-
if df["category"].dtype.name == "category":
546-
df["category_code"] = df["category"].cat.codes
547-
else:
548-
df["category_code"] = df["category"]
571+
trn_cnt = trn_cnt or df["target_cnt"].sum()
572+
syn_cnt = syn_cnt or df["synthetic_cnt"].sum()
573+
df["target_pct"] = df["target_cnt"] / trn_cnt
574+
df["synthetic_pct"] = df["synthetic_cnt"] / syn_cnt
575+
cat_order = list(t["category"])
576+
cat_order.extend([syn_cat for syn_cat in s["category"] if syn_cat not in cat_order])
577+
df["category_order"] = df["category"].map({cat: i for i, cat in enumerate(cat_order)})
549578
if sort_by_frequency:
550579
df = df.sort_values("target_pct", ascending=False).reset_index(drop=True)
551580
else:
552-
df = df.sort_values("category_code", ascending=True).reset_index(drop=True)
581+
df = df.sort_values("category_order", ascending=True).reset_index(drop=True)
553582
return df
554583

555584

556585
def plot_univariate_distribution_categorical(
557-
trn_cat_col_cnts: pd.Series, syn_cat_col_cnts: pd.Series
586+
trn_cat_col_cnts: pd.Series,
587+
syn_cat_col_cnts: pd.Series,
588+
trn_cnt: int | None = None,
589+
syn_cnt: int | None = None,
590+
max_label_length: int = 10,
558591
) -> tuple[go.Scatter, go.Scatter]:
559592
# prepare data
560-
df = prepare_categorical_plot_data_distribution(trn_cat_col_cnts, syn_cat_col_cnts)
561-
df = df.sort_values("avg_cnt", ascending=False)
593+
df = prepare_categorical_plot_data_distribution(trn_cat_col_cnts, syn_cat_col_cnts, trn_cnt, syn_cnt)
562594
# trim labels
563-
df["category"] = trim_labels(df["category"], max_length=10)
595+
df["category"] = trim_labels(df["category"], max_length=max_label_length)
564596
# prepare plots
565597
trn_line = go.Scatter(
566598
mode="lines",
@@ -587,9 +619,11 @@ def plot_univariate_binned(
587619
trn_bin_col_cnts: pd.Series,
588620
syn_bin_col_cnts: pd.Series,
589621
sort_by_frequency: bool = False,
622+
trn_cnt: int | None = None,
623+
syn_cnt: int | None = None,
590624
) -> tuple[go.Scatter, go.Scatter]:
591625
# prepare data
592-
df = prepare_categorical_plot_data_binned(trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency)
626+
df = prepare_categorical_plot_data_binned(trn_bin_col_cnts, syn_bin_col_cnts, sort_by_frequency, trn_cnt, syn_cnt)
593627
# prepare plots
594628
trn_line = go.Scatter(
595629
mode="lines+markers",
@@ -941,7 +975,11 @@ def binning_data(
941975
return trn_bin, syn_bin
942976

943977

944-
def bin_data(df: pd.DataFrame, bins: int | dict[str, list]) -> tuple[pd.DataFrame, dict[str, list]]:
978+
def bin_data(
979+
df: pd.DataFrame,
980+
bins: int | dict[str, list],
981+
non_categorical_label_style: Literal["bounded", "unbounded"] = "unbounded",
982+
) -> tuple[pd.DataFrame, dict[str, list]]:
945983
"""
946984
Splits data into bins.
947985
Binning algorithm depends on column type. Categorical binning creates 'n' bins corresponding to the highest
@@ -962,20 +1000,20 @@ def bin_data(df: pd.DataFrame, bins: int | dict[str, list]) -> tuple[pd.DataFram
9621000
cat_cols = [c for c in df.columns if c not in num_cols + dat_cols]
9631001
if isinstance(bins, int):
9641002
for col in num_cols:
965-
cols[col], bins_dct[col] = bin_numeric(df[col], bins)
1003+
cols[col], bins_dct[col] = bin_numeric(df[col], bins, label_style=non_categorical_label_style)
9661004
for col in dat_cols:
967-
cols[col], bins_dct[col] = bin_datetime(df[col], bins)
1005+
cols[col], bins_dct[col] = bin_datetime(df[col], bins, label_style=non_categorical_label_style)
9681006
for col in cat_cols:
9691007
cols[col], bins_dct[col] = bin_categorical(df[col], bins)
9701008
else: # bins is a dict
9711009
for col in num_cols:
9721010
if col in bins:
973-
cols[col], _ = bin_numeric(df[col], bins[col])
1011+
cols[col], _ = bin_numeric(df[col], bins[col], label_style=non_categorical_label_style)
9741012
else:
9751013
_LOG.warning(f"'{col}' is missing in bins")
9761014
for col in dat_cols:
9771015
if col in bins:
978-
cols[col], _ = bin_datetime(df[col], bins[col])
1016+
cols[col], _ = bin_datetime(df[col], bins[col], label_style=non_categorical_label_style)
9791017
else:
9801018
_LOG.warning(f"'{col}' is missing in bins")
9811019
for col in cat_cols:
@@ -987,7 +1025,9 @@ def bin_data(df: pd.DataFrame, bins: int | dict[str, list]) -> tuple[pd.DataFram
9871025
return pd.DataFrame(cols), bins_dct
9881026

9891027

990-
def bin_numeric(col: pd.Series, bins: int | list[str]) -> tuple[pd.Categorical, list]:
1028+
def bin_numeric(
1029+
col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded"] = "unbounded"
1030+
) -> tuple[pd.Categorical, list]:
9911031
def _clip(col, bins):
9921032
if isinstance(bins, list):
9931033
# use precomputed bin boundaries
@@ -1031,10 +1071,12 @@ def _define_labels(breaks):
10311071
def _adjust_breaks(breaks):
10321072
return breaks[:-1] + [breaks[-1] + 1]
10331073

1034-
return bin_non_categorical(col, bins, _clip, _define_labels, _adjust_breaks)
1074+
return bin_non_categorical(col, bins, _clip, _define_labels, _adjust_breaks, label_style=label_style)
10351075

10361076

1037-
def bin_datetime(col: pd.Series, bins: int | list[str]) -> tuple[pd.Categorical, list]:
1077+
def bin_datetime(
1078+
col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded"] = "unbounded"
1079+
) -> tuple[pd.Categorical, list]:
10381080
def _clip(col, bins):
10391081
if isinstance(bins, list):
10401082
# use precomputed bin boundaries
@@ -1077,7 +1119,7 @@ def _define_labels(breaks):
10771119
def _adjust_breaks(breaks):
10781120
return breaks[:-1] + [max(breaks[-1] + np.timedelta64(1, "D"), breaks[-1])]
10791121

1080-
return bin_non_categorical(col, bins, _clip, _define_labels, _adjust_breaks)
1122+
return bin_non_categorical(col, bins, _clip, _define_labels, _adjust_breaks, label_style=label_style)
10811123

10821124

10831125
def bin_non_categorical(
@@ -1086,6 +1128,7 @@ def bin_non_categorical(
10861128
clip_and_breaks: Callable,
10871129
create_labels: Callable,
10881130
adjust_breaks: Callable,
1131+
label_style: Literal["bounded", "unbounded"] = "unbounded",
10891132
) -> tuple[pd.Categorical, list]:
10901133
col = col.fillna(np.nan).infer_objects(copy=False)
10911134

@@ -1104,7 +1147,10 @@ def bin_non_categorical(
11041147
)
11051148
labels = [str(b) for b in breaks[:-1]]
11061149

1107-
new_labels_map = {label: f"⪰ {label}" for label in labels}
1150+
if label_style == "unbounded":
1151+
new_labels_map = {label: f"⪰ {label}" for label in labels}
1152+
else: # label_style == "bounded"
1153+
new_labels_map = {label: f"⪰ {label}{next_label}" for label, next_label in zip(labels, labels[1:] + ["∞"])}
11081154

11091155
bin_col = pd.cut(col, bins=adjust_breaks(breaks), labels=labels, right=False)
11101156
bin_col = bin_col.cat.rename_categories(new_labels_map)

0 commit comments

Comments
 (0)