1616import hashlib
1717import logging
1818import math
19- from typing import Any
19+ from typing import Any , Literal
2020from collections .abc import Callable , Iterable
2121
2222import fastcluster
@@ -196,22 +196,27 @@ def calculate_accuracy(trn_bin_cols: pd.DataFrame, syn_bin_cols: pd.DataFrame) -
196196 that can be expected due to the sampling noise.
197197 """
198198
199- # create relative frequency tables for `trn` and `syn`
200- trn_freq = trn_bin_cols .value_counts (normalize = True )
201- syn_freq = syn_bin_cols .value_counts (normalize = True )
199+ trn_bin_cnts = trn_bin_cols .value_counts ()
200+ syn_bin_cnts = syn_bin_cols .value_counts ()
201+ return calculate_accuracy_cnts (trn_bin_cnts , syn_bin_cnts )
202+
203+
204+ def calculate_accuracy_cnts (trn_bin_cnts : pd .Series , syn_bin_cnts : pd .Series ) -> tuple [np .float64 , np .float64 ]:
205+ n_trn = trn_bin_cnts .sum ()
206+ n_syn = syn_bin_cnts .sum ()
207+ trn_freq = trn_bin_cnts / n_trn
208+ syn_freq = syn_bin_cnts / n_syn
202209 freq = pd .merge (
203210 trn_freq .to_frame ("tgt" ).reset_index (),
204211 syn_freq .to_frame ("syn" ).reset_index (),
205212 how = "outer" ,
206- on = list (trn_bin_cols . columns ),
213+ on = list (trn_bin_cnts . index . names ),
207214 )
208215 freq ["tgt" ] = freq ["tgt" ].fillna (0.0 )
209216 freq ["syn" ] = freq ["syn" ].fillna (0.0 )
210217 # calculate L1 distance between `trn` and `syn`
211218 observed_l1 = (freq ["tgt" ] - freq ["syn" ]).abs ().sum ()
212219 # calculate the expected L1 distance based on `trn`
213- n_trn = trn_bin_cols .shape [0 ]
214- n_syn = syn_bin_cols .shape [0 ]
215220 expected_l1 = calculate_expected_l1_multinomial (freq ["tgt" ].to_list (), n_trn , n_syn )
216221 # convert to accuracy; trim superfluous precision
217222 observed_acc = (1 - observed_l1 / 2 ).round (5 )
@@ -413,14 +418,14 @@ def plot_store_univariate(
413418 workspace : TemporaryWorkspace ,
414419) -> None :
415420 fig = plot_univariate (
416- col ,
417- trn_num_kde ,
418- syn_num_kde ,
419- trn_cat_col_cnts ,
420- syn_cat_col_cnts ,
421- trn_bin_col_cnts ,
422- syn_bin_col_cnts ,
423- accuracy ,
421+ col_name = col ,
422+ trn_num_kde = trn_num_kde ,
423+ syn_num_kde = syn_num_kde ,
424+ trn_cat_col_cnts = trn_cat_col_cnts ,
425+ syn_cat_col_cnts = syn_cat_col_cnts ,
426+ trn_bin_col_cnts = trn_bin_col_cnts ,
427+ syn_bin_col_cnts = syn_bin_col_cnts ,
428+ accuracy = accuracy ,
424429 )
425430 workspace .store_figure_html (fig , "univariate" , col )
426431
@@ -433,7 +438,11 @@ def plot_univariate(
433438 syn_cat_col_cnts : pd .Series | None ,
434439 trn_bin_col_cnts : pd .Series ,
435440 syn_bin_col_cnts : pd .Series ,
436- accuracy : float | None ,
441+ trn_cnt : int | None = None ,
442+ syn_cnt : int | None = None ,
443+ accuracy : float | None = None ,
444+ sort_categorical_binned_by_frequency : bool = True ,
445+ max_label_length : int = 10 ,
437446) -> go .Figure :
438447 # either numerical/datetime KDEs or categorical counts must be provided
439448
@@ -480,13 +489,27 @@ def plot_univariate(
480489 is_numeric = trn_num_kde is not None
481490 if is_numeric :
482491 trn_line1 , syn_line1 = plot_univariate_distribution_numeric (trn_num_kde , syn_num_kde )
483- trn_line2 , syn_line2 = plot_univariate_binned (trn_bin_col_cnts , syn_bin_col_cnts , sort_by_frequency = False )
492+ trn_line2 , syn_line2 = plot_univariate_binned (
493+ trn_bin_col_cnts ,
494+ syn_bin_col_cnts ,
495+ sort_by_frequency = False ,
496+ trn_cnt = trn_cnt ,
497+ syn_cnt = syn_cnt ,
498+ )
484499 # prevent Plotly from trying to convert strings to dates
485500 fig .layout .xaxis2 .update (type = "category" )
486501 else :
487502 fig .layout .yaxis .update (tickformat = ".0%" )
488- trn_line1 , syn_line1 = plot_univariate_distribution_categorical (trn_cat_col_cnts , syn_cat_col_cnts )
489- trn_line2 , syn_line2 = plot_univariate_binned (trn_bin_col_cnts , syn_bin_col_cnts , sort_by_frequency = True )
503+ trn_line1 , syn_line1 = plot_univariate_distribution_categorical (
504+ trn_cat_col_cnts , syn_cat_col_cnts , trn_cnt , syn_cnt , max_label_length = max_label_length
505+ )
506+ trn_line2 , syn_line2 = plot_univariate_binned (
507+ trn_bin_col_cnts ,
508+ syn_bin_col_cnts ,
509+ sort_by_frequency = sort_categorical_binned_by_frequency ,
510+ trn_cnt = trn_cnt ,
511+ syn_cnt = syn_cnt ,
512+ )
490513 # prevent Plotly from trying to convert strings to dates
491514 fig .layout .xaxis .update (type = "category" )
492515 fig .layout .xaxis2 .update (type = "category" )
@@ -505,62 +528,71 @@ def plot_univariate(
def prepare_categorical_plot_data_distribution(
    trn_col_cnts: pd.Series,
    syn_col_cnts: pd.Series,
    trn_cnt: int | None = None,
    syn_cnt: int | None = None,
) -> pd.DataFrame:
    """
    Combine training and synthetic category counts into one table for plotting.

    Args:
        trn_col_cnts: Counts per category for the training data, indexed by category.
        syn_col_cnts: Counts per category for the synthetic data, indexed by category.
        trn_cnt: Denominator for `target_pct`. When not provided (or zero), the sum of
            the retained `target_cnt` values is used.
        syn_cnt: Denominator for `synthetic_pct`. When not provided (or zero), the sum of
            the retained `synthetic_cnt` values is used.

    Returns:
        DataFrame with columns `category`, `target_cnt`, `synthetic_cnt`, `avg_cnt`,
        `target_pct` and `synthetic_pct`, sorted by `avg_cnt` in descending order and
        with a fresh 0..n-1 index. Categories whose average count is zero are dropped.
    """
    # normalize the category labels: cast to string, then substitute the module-level
    # placeholders NA_BIN for missing labels and EMPTY_BIN for empty-string labels
    trn_col_cnts_idx = trn_col_cnts.index.to_series().astype("string").fillna(NA_BIN).replace("", EMPTY_BIN)
    syn_col_cnts_idx = syn_col_cnts.index.to_series().astype("string").fillna(NA_BIN).replace("", EMPTY_BIN)
    trn_col_cnts = trn_col_cnts.set_axis(trn_col_cnts_idx)
    syn_col_cnts = syn_col_cnts.set_axis(syn_col_cnts_idx)
    t = trn_col_cnts.to_frame("target_cnt").reset_index(names="category")
    s = syn_col_cnts.to_frame("synthetic_cnt").reset_index(names="category")
    # outer join keeps categories that occur in only one of the two datasets
    df = pd.merge(t, s, on="category", how="outer")
    df["target_cnt"] = df["target_cnt"].fillna(0.0)
    df["synthetic_cnt"] = df["synthetic_cnt"].fillna(0.0)
    df["avg_cnt"] = (df["target_cnt"] + df["synthetic_cnt"]) / 2
    # take an explicit copy of the filtered rows, so that the column assignments below
    # operate on an independent frame instead of a slice (avoids SettingWithCopyWarning)
    df = df[df["avg_cnt"] > 0].copy()
    # fall back to the observed totals when no explicit denominators are given
    trn_cnt = trn_cnt or df["target_cnt"].sum()
    syn_cnt = syn_cnt or df["synthetic_cnt"].sum()
    df["target_pct"] = df["target_cnt"] / trn_cnt
    df["synthetic_pct"] = df["synthetic_cnt"] / syn_cnt
    df = df.sort_values("avg_cnt", ascending=False).reset_index(drop=True)
    return df
529551
530552
531553def prepare_categorical_plot_data_binned (
532554 trn_bin_col_cnts : pd .Series ,
533555 syn_bin_col_cnts : pd .Series ,
534556 sort_by_frequency : bool ,
557+ trn_cnt : int | None = None ,
558+ syn_cnt : int | None = None ,
535559) -> pd .DataFrame :
536560 t = trn_bin_col_cnts .to_frame ("target_cnt" ).reset_index (names = "category" )
537561 s = syn_bin_col_cnts .to_frame ("synthetic_cnt" ).reset_index (names = "category" )
538- df = pd .merge (t , s , on = "category" , how = "outer" )
562+ df = pd .merge (t , s , on = "category" , how = "left" )
563+ df = df .set_index ("category" ).reindex (t ["category" ]).reset_index ()
564+ missing_s = s [~ s ["category" ].isin (t ["category" ])]
565+ if not missing_s .empty :
566+ df = pd .concat ([df , missing_s ], ignore_index = True )
539567 df ["target_cnt" ] = df ["target_cnt" ].fillna (0.0 )
540568 df ["synthetic_cnt" ] = df ["synthetic_cnt" ].fillna (0.0 )
541569 df ["avg_cnt" ] = (df ["target_cnt" ] + df ["synthetic_cnt" ]) / 2
542570 df = df [df ["avg_cnt" ] > 0 ]
543- df ["target_pct" ] = df ["target_cnt" ] / df ["target_cnt" ].sum ()
544- df ["synthetic_pct" ] = df ["synthetic_cnt" ] / df ["synthetic_cnt" ].sum ()
545- if df ["category" ].dtype .name == "category" :
546- df ["category_code" ] = df ["category" ].cat .codes
547- else :
548- df ["category_code" ] = df ["category" ]
571+ trn_cnt = trn_cnt or df ["target_cnt" ].sum ()
572+ syn_cnt = syn_cnt or df ["synthetic_cnt" ].sum ()
573+ df ["target_pct" ] = df ["target_cnt" ] / trn_cnt
574+ df ["synthetic_pct" ] = df ["synthetic_cnt" ] / syn_cnt
575+ cat_order = list (t ["category" ])
576+ cat_order .extend ([syn_cat for syn_cat in s ["category" ] if syn_cat not in cat_order ])
577+ df ["category_order" ] = df ["category" ].map ({cat : i for i , cat in enumerate (cat_order )})
549578 if sort_by_frequency :
550579 df = df .sort_values ("target_pct" , ascending = False ).reset_index (drop = True )
551580 else :
552- df = df .sort_values ("category_code " , ascending = True ).reset_index (drop = True )
581+ df = df .sort_values ("category_order " , ascending = True ).reset_index (drop = True )
553582 return df
554583
555584
556585def plot_univariate_distribution_categorical (
557- trn_cat_col_cnts : pd .Series , syn_cat_col_cnts : pd .Series
586+ trn_cat_col_cnts : pd .Series ,
587+ syn_cat_col_cnts : pd .Series ,
588+ trn_cnt : int | None = None ,
589+ syn_cnt : int | None = None ,
590+ max_label_length : int = 10 ,
558591) -> tuple [go .Scatter , go .Scatter ]:
559592 # prepare data
560- df = prepare_categorical_plot_data_distribution (trn_cat_col_cnts , syn_cat_col_cnts )
561- df = df .sort_values ("avg_cnt" , ascending = False )
593+ df = prepare_categorical_plot_data_distribution (trn_cat_col_cnts , syn_cat_col_cnts , trn_cnt , syn_cnt )
562594 # trim labels
563- df ["category" ] = trim_labels (df ["category" ], max_length = 10 )
595+ df ["category" ] = trim_labels (df ["category" ], max_length = max_label_length )
564596 # prepare plots
565597 trn_line = go .Scatter (
566598 mode = "lines" ,
@@ -587,9 +619,11 @@ def plot_univariate_binned(
587619 trn_bin_col_cnts : pd .Series ,
588620 syn_bin_col_cnts : pd .Series ,
589621 sort_by_frequency : bool = False ,
622+ trn_cnt : int | None = None ,
623+ syn_cnt : int | None = None ,
590624) -> tuple [go .Scatter , go .Scatter ]:
591625 # prepare data
592- df = prepare_categorical_plot_data_binned (trn_bin_col_cnts , syn_bin_col_cnts , sort_by_frequency )
626+ df = prepare_categorical_plot_data_binned (trn_bin_col_cnts , syn_bin_col_cnts , sort_by_frequency , trn_cnt , syn_cnt )
593627 # prepare plots
594628 trn_line = go .Scatter (
595629 mode = "lines+markers" ,
@@ -941,7 +975,11 @@ def binning_data(
941975 return trn_bin , syn_bin
942976
943977
944- def bin_data (df : pd .DataFrame , bins : int | dict [str , list ]) -> tuple [pd .DataFrame , dict [str , list ]]:
978+ def bin_data (
979+ df : pd .DataFrame ,
980+ bins : int | dict [str , list ],
981+ non_categorical_label_style : Literal ["bounded" , "unbounded" ] = "unbounded" ,
982+ ) -> tuple [pd .DataFrame , dict [str , list ]]:
945983 """
946984 Splits data into bins.
947985 Binning algorithm depends on column type. Categorical binning creates 'n' bins corresponding to the highest
@@ -962,20 +1000,20 @@ def bin_data(df: pd.DataFrame, bins: int | dict[str, list]) -> tuple[pd.DataFram
9621000 cat_cols = [c for c in df .columns if c not in num_cols + dat_cols ]
9631001 if isinstance (bins , int ):
9641002 for col in num_cols :
965- cols [col ], bins_dct [col ] = bin_numeric (df [col ], bins )
1003+ cols [col ], bins_dct [col ] = bin_numeric (df [col ], bins , label_style = non_categorical_label_style )
9661004 for col in dat_cols :
967- cols [col ], bins_dct [col ] = bin_datetime (df [col ], bins )
1005+ cols [col ], bins_dct [col ] = bin_datetime (df [col ], bins , label_style = non_categorical_label_style )
9681006 for col in cat_cols :
9691007 cols [col ], bins_dct [col ] = bin_categorical (df [col ], bins )
9701008 else : # bins is a dict
9711009 for col in num_cols :
9721010 if col in bins :
973- cols [col ], _ = bin_numeric (df [col ], bins [col ])
1011+ cols [col ], _ = bin_numeric (df [col ], bins [col ], label_style = non_categorical_label_style )
9741012 else :
9751013 _LOG .warning (f"'{ col } ' is missing in bins" )
9761014 for col in dat_cols :
9771015 if col in bins :
978- cols [col ], _ = bin_datetime (df [col ], bins [col ])
1016+ cols [col ], _ = bin_datetime (df [col ], bins [col ], label_style = non_categorical_label_style )
9791017 else :
9801018 _LOG .warning (f"'{ col } ' is missing in bins" )
9811019 for col in cat_cols :
@@ -987,7 +1025,9 @@ def bin_data(df: pd.DataFrame, bins: int | dict[str, list]) -> tuple[pd.DataFram
9871025 return pd .DataFrame (cols ), bins_dct
9881026
9891027
990- def bin_numeric (col : pd .Series , bins : int | list [str ]) -> tuple [pd .Categorical , list ]:
1028+ def bin_numeric (
1029+ col : pd .Series , bins : int | list [str ], label_style : Literal ["bounded" , "unbounded" ] = "unbounded"
1030+ ) -> tuple [pd .Categorical , list ]:
9911031 def _clip (col , bins ):
9921032 if isinstance (bins , list ):
9931033 # use precomputed bin boundaries
@@ -1031,10 +1071,12 @@ def _define_labels(breaks):
10311071 def _adjust_breaks (breaks ):
10321072 return breaks [:- 1 ] + [breaks [- 1 ] + 1 ]
10331073
1034- return bin_non_categorical (col , bins , _clip , _define_labels , _adjust_breaks )
1074+ return bin_non_categorical (col , bins , _clip , _define_labels , _adjust_breaks , label_style = label_style )
10351075
10361076
1037- def bin_datetime (col : pd .Series , bins : int | list [str ]) -> tuple [pd .Categorical , list ]:
1077+ def bin_datetime (
1078+ col : pd .Series , bins : int | list [str ], label_style : Literal ["bounded" , "unbounded" ] = "unbounded"
1079+ ) -> tuple [pd .Categorical , list ]:
10381080 def _clip (col , bins ):
10391081 if isinstance (bins , list ):
10401082 # use precomputed bin boundaries
@@ -1077,7 +1119,7 @@ def _define_labels(breaks):
10771119 def _adjust_breaks (breaks ):
10781120 return breaks [:- 1 ] + [max (breaks [- 1 ] + np .timedelta64 (1 , "D" ), breaks [- 1 ])]
10791121
1080- return bin_non_categorical (col , bins , _clip , _define_labels , _adjust_breaks )
1122+ return bin_non_categorical (col , bins , _clip , _define_labels , _adjust_breaks , label_style = label_style )
10811123
10821124
10831125def bin_non_categorical (
@@ -1086,6 +1128,7 @@ def bin_non_categorical(
10861128 clip_and_breaks : Callable ,
10871129 create_labels : Callable ,
10881130 adjust_breaks : Callable ,
1131+ label_style : Literal ["bounded" , "unbounded" ] = "unbounded" ,
10891132) -> tuple [pd .Categorical , list ]:
10901133 col = col .fillna (np .nan ).infer_objects (copy = False )
10911134
@@ -1104,7 +1147,10 @@ def bin_non_categorical(
11041147 )
11051148 labels = [str (b ) for b in breaks [:- 1 ]]
11061149
1107- new_labels_map = {label : f"⪰ { label } " for label in labels }
1150+ if label_style == "unbounded" :
1151+ new_labels_map = {label : f"⪰ { label } " for label in labels }
1152+ else : # label_style == "bounded"
1153+ new_labels_map = {label : f"⪰ { label } ≺ { next_label } " for label , next_label in zip (labels , labels [1 :] + ["∞" ])}
11081154
11091155 bin_col = pd .cut (col , bins = adjust_breaks (breaks ), labels = labels , right = False )
11101156 bin_col = bin_col .cat .rename_categories (new_labels_map )
0 commit comments