Skip to content

Commit 13f8616

Browse files
feat: warn user on inconsistent dtypes (#128)
1 parent 1ea85c2 commit 13f8616

File tree

3 files changed

+89
-19
lines changed

3 files changed

+89
-19
lines changed

mostlyai/qa/_sampling.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import logging
2929
import random
3030
import time
31+
from typing import Any
32+
from pandas.core.dtypes.common import is_numeric_dtype, is_datetime64_dtype
3133

3234
import numpy as np
3335
import pandas as pd
@@ -56,6 +58,7 @@ def pull_data_for_accuracy(
5658
tgt_context_key: str | None = None,
5759
max_sample_size: int | None = None,
5860
setup: str | None = None,
61+
trn_dtypes: dict[str, str] | None = None,
5962
) -> pd.DataFrame:
6063
"""
6164
Prepare single dataset for accuracy report.
@@ -130,6 +133,14 @@ def pull_data_for_accuracy(
130133
# harmonize dtypes
131134
df = df.apply(harmonize_dtype)
132135

136+
# coerce dtypes to trn_dtypes
137+
for trn_col, trn_dtype in (trn_dtypes or {}).items():
138+
if is_numeric_dtype(trn_dtype):
139+
df[trn_col] = pd.to_numeric(df[trn_col], errors="coerce")
140+
elif is_datetime64_dtype(trn_dtype):
141+
df[trn_col] = pd.to_datetime(df[trn_col], errors="coerce")
142+
df[trn_col] = df[trn_col].astype(trn_dtype)
143+
133144
# sample tokens from text-like columns
134145
df = sample_text_tokens(df)
135146

@@ -303,10 +314,10 @@ def calculate_embeddings(
303314
def sample_text_tokens(df: pd.DataFrame) -> pd.DataFrame:
304315
tokenizer = load_tokenizer()
305316

306-
def tokenize_and_sample(text: str | None) -> str | None:
317+
def tokenize_and_sample(text: Any) -> str | None:
307318
if pd.isna(text) or text == "":
308319
return None
309-
tokens = tokenizer.tokenize(text)
320+
tokens = tokenizer.tokenize(str(text))
310321
tokens = (t.replace("Ġ", "▁") for t in tokens) # replace initial space with thick underscore
311322
return random.choice(list(tokens))
312323

@@ -337,7 +348,7 @@ def is_timestamp_dtype(x: pd.Series) -> bool:
337348
else:
338349
x = x.astype("object")
339350
except Exception:
340-
# leave dtype as-is, but just log a warning message
351+
# leave dtype as-is
341352
pass
342353
return x
343354

mostlyai/qa/reporting.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ def report(
148148
if hol_ctx_data is not None and trn_ctx_data is not None:
149149
hol_ctx_data = hol_ctx_data[trn_ctx_data.columns]
150150

151+
# warn if dtypes are inconsistent across datasets
152+
_warn_if_dtypes_inconsistent(syn_tgt_data, trn_tgt_data, hol_tgt_data)
153+
_warn_if_dtypes_inconsistent(syn_ctx_data, trn_ctx_data, hol_ctx_data)
154+
151155
# prepare report_path
152156
if report_path is None:
153157
report_path = Path.cwd() / "model-report.html"
@@ -200,36 +204,29 @@ def report(
200204
else:
201205
setup = "1:1"
202206

203-
_LOG.info("prepare synthetic data for accuracy started")
204-
syn = pull_data_for_accuracy(
205-
df_tgt=syn_tgt_data,
206-
df_ctx=syn_ctx_data,
207+
_LOG.info("prepare training data for accuracy started")
208+
trn = pull_data_for_accuracy(
209+
df_tgt=trn_tgt_data,
210+
df_ctx=trn_ctx_data,
207211
ctx_primary_key=ctx_primary_key,
208212
tgt_context_key=tgt_context_key,
209213
max_sample_size=max_sample_size_accuracy,
210214
setup=setup,
211215
)
212216
progress.update(completed=5, total=100)
213217

214-
_LOG.info("prepare training data for accuracy started")
215-
trn = pull_data_for_accuracy(
216-
df_tgt=trn_tgt_data,
217-
df_ctx=trn_ctx_data,
218+
_LOG.info("prepare synthetic data for accuracy started")
219+
syn = pull_data_for_accuracy(
220+
df_tgt=syn_tgt_data,
221+
df_ctx=syn_ctx_data,
218222
ctx_primary_key=ctx_primary_key,
219223
tgt_context_key=tgt_context_key,
220224
max_sample_size=max_sample_size_accuracy,
221225
setup=setup,
226+
trn_dtypes=trn.dtypes.to_dict(),
222227
)
223228
progress.update(completed=10, total=100)
224229

225-
# coerce dtypes to match the original training data dtypes
226-
for col in trn:
227-
if is_numeric_dtype(trn[col]):
228-
syn[col] = pd.to_numeric(syn[col], errors="coerce")
229-
elif is_datetime64_dtype(trn[col]):
230-
syn[col] = pd.to_datetime(syn[col], errors="coerce")
231-
syn[col] = syn[col].astype(trn[col].dtype)
232-
233230
_LOG.info("report accuracy and correlations")
234231
acc_uni, acc_biv, corr_trn = _report_accuracy_and_correlations(
235232
trn=trn,
@@ -396,6 +393,29 @@ def report(
396393
return report_path, metrics
397394

398395

396+
def _warn_if_dtypes_inconsistent(syn_df: pd.DataFrame | None, trn_df: pd.DataFrame | None, hol_df: pd.DataFrame | None):
397+
dfs = [df for df in (syn_df, trn_df, hol_df) if df is not None]
398+
if not dfs:
399+
return
400+
common_columns = set.intersection(*[set(df.columns) for df in dfs])
401+
column_dtypes = {col: [df[col].dtype for df in dfs] for col in common_columns}
402+
inconsistent_columns = []
403+
for col, dtypes in column_dtypes.items():
404+
any_datetimes = any(is_datetime64_dtype(dtype) for dtype in dtypes)
405+
any_numbers = any(is_numeric_dtype(dtype) for dtype in dtypes)
406+
any_others = any(not is_datetime64_dtype(dtype) and not is_numeric_dtype(dtype) for dtype in dtypes)
407+
if sum([any_datetimes, any_numbers, any_others]) > 1:
408+
inconsistent_columns.append(col)
409+
if inconsistent_columns:
410+
warnings.warn(
411+
UserWarning(
412+
f"The column(s) {inconsistent_columns} have inconsistent data types across `syn`, `trn`, and `hol`. "
413+
"To achieve the most accurate results, please harmonize the data types of these inputs. "
414+
"Proceeding with a best-effort attempt..."
415+
)
416+
)
417+
418+
399419
def _calculate_metrics(
400420
*,
401421
acc_uni: pd.DataFrame,

tests/end_to_end/test_report.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414

1515
import uuid
1616
from pathlib import Path
17+
import warnings
1718

1819
import pandas as pd
1920
import numpy as np
2021

2122
from mostlyai import qa
23+
from datetime import datetime, timedelta
2224

2325

2426
def mock_data(n):
@@ -278,3 +280,40 @@ def test_missing(tmp_path):
278280
trn_tgt_data=df1,
279281
)
280282
assert metrics is not None
283+
284+
285+
def test_mixed_dtypes(tmp_path):
286+
# test that datetime columns drawn from the same distribution, but having different dtype
287+
# are still yielding somewhat good results and warning is issued
288+
289+
def generate_dates(start_date, end_date, num_samples):
290+
days_range = (end_date - start_date).days
291+
return [start_date + timedelta(days=int(days)) for days in np.random.randint(0, days_range, num_samples)]
292+
293+
num_samples = 200
294+
start_date = datetime(2020, 1, 1)
295+
end_date = datetime(2023, 12, 31)
296+
df = pd.DataFrame(
297+
{
298+
"trn_dt": pd.Series(generate_dates(start_date, end_date, num_samples)).values.astype(str),
299+
"syn_dt": pd.Series(generate_dates(start_date, end_date, num_samples), dtype="datetime64[ns]"),
300+
}
301+
)
302+
trn_df, syn_df = df["trn_dt"].to_frame("dt"), df["syn_dt"].to_frame("dt")
303+
304+
with warnings.catch_warnings(record=True) as w:
305+
_, statistics = qa.report(
306+
syn_tgt_data=syn_df,
307+
trn_tgt_data=trn_df,
308+
report_path=tmp_path / "report.html",
309+
)
310+
expected_warning = (
311+
"The column(s) ['dt'] have inconsistent data types across `syn`, `trn`, and `hol`. "
312+
"To achieve the most accurate results, please harmonize the data types of these inputs. "
313+
"Proceeding with a best-effort attempt..."
314+
)
315+
assert any(expected_warning in str(warning.message) for warning in w), (
316+
"Expected a warning about dtype mismatch for column 'dt'"
317+
)
318+
assert statistics.accuracy.overall > 0.6
319+
assert 0.4 < statistics.similarity.discriminator_auc_training_synthetic < 0.6

0 commit comments

Comments
 (0)