@@ -148,6 +148,10 @@ def report(
148148 if hol_ctx_data is not None and trn_ctx_data is not None :
149149 hol_ctx_data = hol_ctx_data [trn_ctx_data .columns ]
150150
151+ # warn if dtypes are inconsistent across datasets
152+ _warn_if_dtypes_inconsistent (syn_tgt_data , trn_tgt_data , hol_tgt_data )
153+ _warn_if_dtypes_inconsistent (syn_ctx_data , trn_ctx_data , hol_ctx_data )
154+
151155 # prepare report_path
152156 if report_path is None :
153157 report_path = Path .cwd () / "model-report.html"
@@ -200,36 +204,29 @@ def report(
200204 else :
201205 setup = "1:1"
202206
203- _LOG .info ("prepare synthetic data for accuracy started" )
204- syn = pull_data_for_accuracy (
205- df_tgt = syn_tgt_data ,
206- df_ctx = syn_ctx_data ,
207+ _LOG .info ("prepare training data for accuracy started" )
208+ trn = pull_data_for_accuracy (
209+ df_tgt = trn_tgt_data ,
210+ df_ctx = trn_ctx_data ,
207211 ctx_primary_key = ctx_primary_key ,
208212 tgt_context_key = tgt_context_key ,
209213 max_sample_size = max_sample_size_accuracy ,
210214 setup = setup ,
211215 )
212216 progress .update (completed = 5 , total = 100 )
213217
214- _LOG .info ("prepare training data for accuracy started" )
215- trn = pull_data_for_accuracy (
216- df_tgt = trn_tgt_data ,
217- df_ctx = trn_ctx_data ,
218+ _LOG .info ("prepare synthetic data for accuracy started" )
219+ syn = pull_data_for_accuracy (
220+ df_tgt = syn_tgt_data ,
221+ df_ctx = syn_ctx_data ,
218222 ctx_primary_key = ctx_primary_key ,
219223 tgt_context_key = tgt_context_key ,
220224 max_sample_size = max_sample_size_accuracy ,
221225 setup = setup ,
226+ trn_dtypes = trn .dtypes .to_dict (),
222227 )
223228 progress .update (completed = 10 , total = 100 )
224229
225- # coerce dtypes to match the original training data dtypes
226- for col in trn :
227- if is_numeric_dtype (trn [col ]):
228- syn [col ] = pd .to_numeric (syn [col ], errors = "coerce" )
229- elif is_datetime64_dtype (trn [col ]):
230- syn [col ] = pd .to_datetime (syn [col ], errors = "coerce" )
231- syn [col ] = syn [col ].astype (trn [col ].dtype )
232-
233230 _LOG .info ("report accuracy and correlations" )
234231 acc_uni , acc_biv , corr_trn = _report_accuracy_and_correlations (
235232 trn = trn ,
@@ -396,6 +393,29 @@ def report(
396393 return report_path , metrics
397394
398395
396+ def _warn_if_dtypes_inconsistent (syn_df : pd .DataFrame | None , trn_df : pd .DataFrame | None , hol_df : pd .DataFrame | None ):
397+ dfs = [df for df in (syn_df , trn_df , hol_df ) if df is not None ]
398+ if not dfs :
399+ return
400+ common_columns = set .intersection (* [set (df .columns ) for df in dfs ])
401+ column_dtypes = {col : [df [col ].dtype for df in dfs ] for col in common_columns }
402+ inconsistent_columns = []
403+ for col , dtypes in column_dtypes .items ():
404+ any_datetimes = any (is_datetime64_dtype (dtype ) for dtype in dtypes )
405+ any_numbers = any (is_numeric_dtype (dtype ) for dtype in dtypes )
406+ any_others = any (not is_datetime64_dtype (dtype ) and not is_numeric_dtype (dtype ) for dtype in dtypes )
407+ if sum ([any_datetimes , any_numbers , any_others ]) > 1 :
408+ inconsistent_columns .append (col )
409+ if inconsistent_columns :
410+ warnings .warn (
411+ UserWarning (
412+ f"The column(s) { inconsistent_columns } have inconsistent data types across `syn`, `trn`, and `hol`. "
413+ "To achieve the most accurate results, please harmonize the data types of these inputs. "
414+ "Proceeding with a best-effort attempt..."
415+ )
416+ )
417+
418+
399419def _calculate_metrics (
400420 * ,
401421 acc_uni : pd .DataFrame ,
0 commit comments