Skip to content

Commit 6b19d95

Browse files
authored
fix issue for DataFrames with mismatching keys
1 parent da90bf1 commit 6b19d95

File tree

2 files changed

+10
-24
lines changed

2 files changed

+10
-24
lines changed

mostlyai/qa/_common.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -132,12 +132,13 @@ def determine_data_size(
132132
tgt_context_key: str | None = None,
133133
) -> int:
134134
if ctx_data is not None and ctx_primary_key is not None:
135-
return len(ctx_data[ctx_primary_key].unique())
136-
elif ctx_data is not None and not ctx_data.empty:
137-
return len(ctx_data)
135+
# consider number of matching keys for sample size
136+
ctx_keys = ctx_data[ctx_primary_key].unique()
137+
tgt_keys = tgt_data[tgt_context_key].unique()
138+
keys = set(ctx_keys).intersection(set(tgt_keys))
139+
return len(keys)
138140
elif tgt_data is not None and tgt_context_key is not None:
139-
return len(tgt_data[tgt_context_key].unique())
140-
elif tgt_data is not None and not tgt_data.empty:
141-
return len(tgt_data)
141+
tgt_keys = tgt_data[tgt_context_key].unique()
142+
return len(tgt_keys)
142143
else:
143-
return 0
144+
return len(tgt_data)

tests/end_to_end/test_report.py

Lines changed: 2 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -205,6 +205,8 @@ def make_dfs(
205205
test_dfs = [
206206
# setups with <100 rows in tgt/ctx should early terminate
207207
{"dfs": make_dfs(ctx_rows=99, tgt_rows=99, ctx_cols=["ctx_col"], tgt_cols=["tgt_col"]), "early_term": True},
208+
{"dfs": make_dfs(ctx_rows=100, tgt_rows=100, shift=90, tgt_cols=["tgt_col"]), "early_term": True},
209+
{"dfs": make_dfs(ctx_rows=100, tgt_rows=100, shift=100, tgt_cols=["tgt_col"]), "early_term": True},
208210
# other setups should produce report
209211
{"dfs": make_dfs(ctx_rows=100, tgt_rows=100), "early_term": False},
210212
{"dfs": make_dfs(ctx_rows=100, tgt_rows=100, ctx_cols=["ctx_col"], tgt_cols=["tgt_col"]), "early_term": False},
@@ -238,23 +240,6 @@ def test_report_few_holdout_records(tmp_path):
238240
assert metrics is not None
239241

240242

241-
def test_report_sequential_few_records(tmp_path):
242-
# ensure that we don't crash in case of dominant zero-seq-length
243-
ctx = pd.DataFrame({"id": list(range(1000))})
244-
tgt = pd.DataFrame({"id": [1, 2, 3, 4, 5] * 100, "col": ["a"] * 500})
245-
_, metrics = qa.report(
246-
syn_tgt_data=tgt,
247-
trn_tgt_data=tgt,
248-
hol_tgt_data=tgt,
249-
syn_ctx_data=ctx,
250-
trn_ctx_data=ctx,
251-
hol_ctx_data=ctx,
252-
tgt_context_key="id",
253-
ctx_primary_key="id",
254-
)
255-
assert metrics is not None
256-
257-
258243
def test_odd_column_names(tmp_path):
259244
values = ["a", "b"] * 50
260245
df = pd.DataFrame(

0 commit comments

Comments
 (0)