Skip to content

Commit 3f24f61

Browse files
fix: incorrect accuracy setup calculation (#85)
1 parent 0f0d022 commit 3f24f61

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

mostlyai/qa/_sampling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,16 @@ def pull_data_for_accuracy(
110110
df = pd.merge(df, df_tgt, on=key, how="left")
111111
df = pd.merge(df, df_nxt, on=key, how="left")
112112
df = df.drop(columns=[key])
113-
114-
# remove records with sequence length equal to 0
115113
count_column = f"{TGT_COLUMN_PREFIX}{COUNT_COLUMN}"
116114
df[count_column] = df[count_column].fillna(0).astype("Int64")
117-
df = df.loc[df[count_column] > 0].reset_index(drop=True)
118115

116+
# determine setup if not provided
119117
if setup is None:
120118
setup = "1:1" if (df[count_column] == 1).all() else "1:N"
121119

120+
# remove records with sequence length equal to 0
121+
df = df.loc[df[count_column] > 0].reset_index(drop=True)
122+
122123
# for 1:1 ctx/tgt setups, drop nxt and count columns; ensure at least one column remains
123124
if setup == "1:1":
124125
df = df.drop(columns=[c for c in df.columns if c.startswith(NXT_COLUMN_PREFIX)])

mostlyai/qa/reporting_from_statistics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def report_from_statistics(
113113
ctx_primary_key=ctx_primary_key,
114114
tgt_context_key=tgt_context_key,
115115
max_sample_size=max_sample_size_accuracy,
116+
# always pull Sequence Length and nxt columns for synthetic data
117+
# and let downstream functions decide if they are needed
118+
setup="1:N",
116119
)
117120
_LOG.info(f"sample synthetic data finished ({syn.shape=})")
118121
progress.update(completed=20, total=100)

0 commit comments

Comments
 (0)