1717import time
1818
1919import numpy as np
20+ import pandas as pd
21+ from sklearn .preprocessing import QuantileTransformer
2022
2123from mostlyai .qa ._common import (
2224 CHARTS_COLORS ,
2325 CHARTS_FONTS ,
26+ EMPTY_BIN ,
27+ NA_BIN ,
28+ RARE_BIN ,
2429)
2530from mostlyai .qa ._filesystem import TemporaryWorkspace
2631from plotly import graph_objs as go
2732
33+ from mostlyai .qa .assets import load_embedder
34+ from sklearn .decomposition import PCA
35+
2836_LOG = logging .getLogger (__name__ )
2937
3038
def encode_numerics(
    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
    """
    Encode numeric (and datetime) features as uniformly distributed values in [-0.5, 0.5].

    Every column of `trn` is coerced to numbers and passed through a QuantileTransformer
    that is fitted on the original data (training plus holdout, if a non-empty holdout is
    given). Missing values are encoded as 0.0; if the original data contains missing values
    for a column, an additional "<col> - N/A" indicator column (1.0 = missing) is emitted.

    Args:
        syn: Synthetic data.
        trn: Training data; its columns define which columns get encoded.
        hol: Holdout data, or None if there is none.

    Returns:
        Tuple of encoded (synthetic, training, holdout) frames, each indexed like its input.
        The holdout entry is None if no (or an empty) holdout was provided.
    """
    if hol is None:
        hol = pd.DataFrame(columns=trn.columns)
    has_holdout = len(hol) > 0

    sources = {"syn": syn, "trn": trn, "hol": hol}
    encoded: dict[str, dict] = {name: {} for name in sources}

    for col in trn.columns:
        # coerce to numerics, and restore NAs that got lost in the conversion (needed for datetime)
        numeric = {}
        for name, frame in sources.items():
            as_number = pd.to_numeric(frame[col], errors="coerce")
            numeric[name] = as_number.where(frame[col].notna(), np.nan)

        # fit the quantile mapping on the original data, i.e. training plus holdout (if any)
        if has_holdout:
            reference = pd.concat([numeric["trn"], numeric["hol"]])
        else:
            reference = pd.DataFrame(numeric["trn"])
        scaler = QuantileTransformer(
            n_quantiles=min(100, len(trn) + len(hol)),
            output_distribution="uniform",
            random_state=42,
        )
        scaler.fit(reference.values.reshape(-1, 1))

        # map to [-0.5, 0.5], with NAs ending up at 0.0
        for name in sources:
            if name == "hol" and not has_holdout:
                # nothing to transform; placeholder only, the holdout result is dropped below
                encoded[name][col] = None
                continue
            centered = scaler.transform(numeric[name].values.reshape(-1, 1))[:, 0] - 0.5
            encoded[name][col] = np.nan_to_num(centered, nan=0.0)

        # add an indicator column if the original data has NAs for this column
        if trn[col].isna().any() or hol[col].isna().any():
            na_col = col + " - N/A"
            for name, frame in sources.items():
                encoded[name][na_col] = frame[col].isna().astype(float)

    syn_num = pd.DataFrame(encoded["syn"], index=syn.index)
    trn_num = pd.DataFrame(encoded["trn"], index=trn.index)
    hol_num = pd.DataFrame(encoded["hol"], index=hol.index) if has_holdout else None
    return syn_num, trn_num, hol_num
81+
82+
def encode_strings(
    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
    """
    Encode string features by mapping them to a low-dimensional space using PCA of their embeddings.

    For every column of `trn`, the unique original values (training plus holdout) and the
    RARE_BIN placeholder are embedded, the embeddings are projected onto 2 principal
    components, and each value is replaced by its projection. Missing values are mapped to
    NA_BIN, empty strings to EMPTY_BIN, and synthetic values that do not occur in the
    original data to RARE_BIN.

    Args:
        syn: Synthetic data.
        trn: Training data; its columns define which columns get encoded.
        hol: Holdout data, or None if there is none.

    Returns:
        Tuple of encoded (synthetic, training, holdout) frames with columns
        "<col> - PCA 1" and "<col> - PCA 2" per input column, each indexed like its input.
        The holdout entry is None if no (or an empty) holdout was provided.
    """
    trn_str, syn_str, hol_str = {}, {}, {}
    if hol is None:
        hol = pd.DataFrame(columns=trn.columns)

    def _as_labels(values: pd.Series) -> pd.Series:
        # The NA mask must come from the raw values: a plain string cast turns missing values
        # into the literals "nan" / "None", after which a `fillna` can no longer find them.
        return values.astype(str).mask(values.isna(), NA_BIN).replace("", EMPTY_BIN)

    dims = 2  # potentially adapt to the number of unique values
    embedder = None  # column-independent, thus loaded once on first use rather than per column
    for col in trn.columns:
        # prepare inputs
        syn_col = _as_labels(syn[col])
        trn_col = _as_labels(trn[col])
        hol_col = _as_labels(hol[col])
        # get unique original values
        uvals = pd.concat([trn_col, hol_col]).value_counts().index.to_list()
        # map out of range values to RARE_BIN
        syn_col = syn_col.where(syn_col.isin(uvals), RARE_BIN)
        # append RARE_BIN, unless the data itself already contains that literal value;
        # a duplicated label would make the lookup via `reindex` below fail
        labels = list(dict.fromkeys(uvals + [RARE_BIN]))
        # embed unique values into high-dimensional space
        if embedder is None:
            embedder = load_embedder()
        embeds = embedder.encode(labels)
        # project embeddings into a low-dimensional space
        pca_model = PCA(n_components=dims)
        embeds = pca_model.fit_transform(embeds)
        # create mapping from unique values to PCA
        embeds = pd.DataFrame(embeds, index=labels)
        # map values to PCA and assign column names
        columns = [f"{col} - PCA {i + 1}" for i in range(dims)]
        for target, labels_col in ((syn_str, syn_col), (trn_str, trn_col), (hol_str, hol_col)):
            projected = embeds.reindex(labels_col.values).reset_index(drop=True)
            projected.columns = columns
            target[col] = projected

    syn_str = pd.concat(syn_str.values(), axis=1) if syn_str else pd.DataFrame()
    syn_str.index = syn.index
    trn_str = pd.concat(trn_str.values(), axis=1) if trn_str else pd.DataFrame()
    trn_str.index = trn.index
    if len(hol) > 0:
        hol_str = pd.concat(hol_str.values(), axis=1) if hol_str else pd.DataFrame()
        hol_str.index = hol.index
    else:
        hol_str = None
    return syn_str, trn_str, hol_str
130+
131+
def encode_data(
    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
    """
    Encode all columns corresponding to their data type.

    Columns that are numeric or datetime in `trn` are handled by `encode_numerics`, all
    remaining columns by `encode_strings`; the results are concatenated column-wise.

    Args:
        syn: Synthetic data.
        trn: Training data; its dtypes decide how each column is encoded.
        hol: Holdout data, or None if there is none.

    Returns:
        Tuple of encoded (synthetic, training, holdout) frames. The holdout entry is None
        if no (or an empty) holdout was provided.
    """
    # Both encoders return None for an empty holdout, and `pd.concat` raises if all of its
    # inputs are None. Hence treat an empty holdout exactly like a missing one.
    if hol is not None and len(hol) == 0:
        hol = None

    num_dat_cols = trn.select_dtypes(include=["number", "datetime"]).columns
    string_cols = [col for col in trn.columns if col not in num_dat_cols]

    syn_num, trn_num, hol_num = encode_numerics(
        syn[num_dat_cols], trn[num_dat_cols], hol[num_dat_cols] if hol is not None else None
    )
    syn_str, trn_str, hol_str = encode_strings(
        syn[string_cols], trn[string_cols], hol[string_cols] if hol is not None else None
    )

    syn_encoded = pd.concat([syn_num, syn_str], axis=1)
    trn_encoded = pd.concat([trn_num, trn_str], axis=1)
    hol_encoded = pd.concat([hol_num, hol_str], axis=1) if hol is not None else None
    return syn_encoded, trn_encoded, hol_encoded
150+
151+
31152def calculate_dcrs_nndrs (
32153 data : np .ndarray | None , query : np .ndarray | None
33154) -> tuple [np .ndarray | None , np .ndarray | None ]:
34155 """
35156 Calculate Distance to Closest Records (DCRs) and Nearest Neighbor Distance Ratios (NNDRs).
36-
37- Args:
38- data: Embeddings of the training data.
39- query: Embeddings of the query set.
40-
41- Returns:
42157 """
43- if data is None or query is None :
158+ if data is None or query is None or data . shape [ 0 ] == 0 or query . shape [ 0 ] == 0 :
44159 return None , None
45160 _LOG .info (f"calculate DCRs for { data .shape = } and { query .shape = } " )
46161 t0 = time .time ()
47162 data = data [data [:, 0 ].argsort ()] # sort data by first dimension to enforce deterministic results
163+
48164 if platform .system () == "Linux" :
49165 # use FAISS on Linux for best performance
50166 import faiss # type: ignore
51167
52- index = faiss .IndexFlatIP (data .shape [1 ]) # inner product for cosine similarity with normalized vectors
168+ index = faiss .IndexFlatL2 (data .shape [1 ])
53169 index .add (data )
54- similarities , _ = index .search (query , 2 )
55- dcrs = np .clip ( 1 - similarities , 0 , 1 )
170+ dcrs , _ = index .search (query , 2 )
171+ dcrs = np .sqrt ( dcrs ) # FAISS returns squared distances
56172 else :
57173 # use sklearn as a fallback on non-Linux systems to avoid segfaults; these occurred when using QA as part of SDK
58174 from sklearn .neighbors import NearestNeighbors # type: ignore
59175 from joblib import cpu_count # type: ignore
60176
61- index = NearestNeighbors (
62- n_neighbors = 2 , algorithm = "auto" , metric = "cosine" , n_jobs = min (16 , max (1 , cpu_count () - 1 ))
63- )
177+ index = NearestNeighbors (n_neighbors = 2 , algorithm = "auto" , metric = "l2" , n_jobs = min (16 , max (1 , cpu_count () - 1 )))
64178 index .fit (data )
65179 dcrs , _ = index .kneighbors (query )
66180 dcr = dcrs [:, 0 ]
@@ -70,34 +184,31 @@ def calculate_dcrs_nndrs(
70184
71185
72186def calculate_distances (
73- * , syn_embeds : np .ndarray , trn_embeds : np .ndarray , hol_embeds : np .ndarray | None
187+ * , syn_encoded : np .ndarray , trn_encoded : np .ndarray , hol_encoded : np .ndarray | None
74188) -> dict [str , np .ndarray ]:
75189 """
76190 Calculates distances to the closest records (DCR).
77-
78- Args:
79- syn_embeds: Embeddings of synthetic data.
80- trn_embeds: Embeddings of training data.
81- hol_embeds: Embeddings of holdout data.
82-
83- Returns:
84- Dictionary containing:
85- - dcr_syn_trn: DCR for synthetic to training.
86- - dcr_syn_hol: DCR for synthetic to holdout.
87- - dcr_trn_hol: DCR for training to holdout.
88- - nndr_syn_trn: NNDR for synthetic to training.
89- - nndr_syn_hol: NNDR for synthetic to holdout.
90- - nndr_trn_hol: NNDR for training to holdout.
91191 """
92- if hol_embeds is not None :
93- assert trn_embeds .shape == hol_embeds .shape
192+ assert syn_encoded .shape == trn_encoded .shape
193+ if hol_encoded is not None and hol_encoded .shape [0 ] > 0 :
194+ assert trn_encoded .shape == hol_encoded .shape
195+
196+ # cap dimensionality of encoded data
197+ max_dims = 256
198+ if trn_encoded .shape [1 ] > max_dims :
199+ _LOG .info (f"capping dimensionality of encoded data from { trn_encoded .shape [1 ]} to { max_dims } " )
200+ pca_model = PCA (n_components = max_dims )
201+ pca_model .fit (np .vstack ((trn_encoded , hol_encoded )))
202+ trn_encoded = pca_model .transform (trn_encoded )
203+ hol_encoded = pca_model .transform (hol_encoded )
204+ syn_encoded = pca_model .transform (syn_encoded )
94205
95206 # calculate DCR / NNDR for synthetic to training
96- dcr_syn_trn , nndr_syn_trn = calculate_dcrs_nndrs (data = trn_embeds , query = syn_embeds )
207+ dcr_syn_trn , nndr_syn_trn = calculate_dcrs_nndrs (data = trn_encoded , query = syn_encoded )
97208 # calculate DCR / NNDR for synthetic to holdout
98- dcr_syn_hol , nndr_syn_hol = calculate_dcrs_nndrs (data = hol_embeds , query = syn_embeds )
209+ dcr_syn_hol , nndr_syn_hol = calculate_dcrs_nndrs (data = hol_encoded , query = syn_encoded )
99210 # calculate DCR / NNDR for holdout to training
100- dcr_trn_hol , nndr_trn_hol = calculate_dcrs_nndrs (data = trn_embeds , query = hol_embeds )
211+ dcr_trn_hol , nndr_trn_hol = calculate_dcrs_nndrs (data = trn_encoded , query = hol_encoded )
101212
102213 # log statistics
103214 def deciles (x ):
0 commit comments