Commit 11560d2
feat: switch to custom encoding space for distance metrics (#185)
1 parent 3cdd4cd commit 11560d2

11 files changed · +299 -130 lines changed

mostlyai/qa/_distances.py

Lines changed: 144 additions & 33 deletions
@@ -17,50 +17,164 @@
 import time

 import numpy as np
+import pandas as pd
+from sklearn.preprocessing import QuantileTransformer

 from mostlyai.qa._common import (
     CHARTS_COLORS,
     CHARTS_FONTS,
+    EMPTY_BIN,
+    NA_BIN,
+    RARE_BIN,
 )
 from mostlyai.qa._filesystem import TemporaryWorkspace
 from plotly import graph_objs as go

+from mostlyai.qa.assets import load_embedder
+from sklearn.decomposition import PCA
+
 _LOG = logging.getLogger(__name__)


+def encode_numerics(
+    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
+) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
+    """
+    Encode numeric features by mapping them via a QuantileTransformer to a uniform distribution on [-0.5, 0.5].
+    """
+    syn_num, trn_num, hol_num = {}, {}, {}
+    if hol is None:
+        hol = pd.DataFrame(columns=trn.columns)
+    for col in trn.columns:
+        # convert to numerics
+        syn_num[col] = pd.to_numeric(syn[col], errors="coerce")
+        trn_num[col] = pd.to_numeric(trn[col], errors="coerce")
+        hol_num[col] = pd.to_numeric(hol[col], errors="coerce")
+        # retain NAs (needed for datetime)
+        syn_num[col] = syn_num[col].where(~syn[col].isna(), np.nan)
+        trn_num[col] = trn_num[col].where(~trn[col].isna(), np.nan)
+        hol_num[col] = hol_num[col].where(~hol[col].isna(), np.nan)
+        # normalize numeric features based on trn
+        qt_scaler = QuantileTransformer(
+            output_distribution="uniform",
+            random_state=42,
+            n_quantiles=min(100, len(trn) + len(hol)),
+        )
+        ori_num = pd.concat([trn_num[col], hol_num[col]]) if len(hol) > 0 else pd.DataFrame(trn_num[col])
+        qt_scaler.fit(ori_num.values.reshape(-1, 1))
+        syn_num[col] = qt_scaler.transform(syn_num[col].values.reshape(-1, 1))[:, 0] - 0.5
+        trn_num[col] = qt_scaler.transform(trn_num[col].values.reshape(-1, 1))[:, 0] - 0.5
+        hol_num[col] = qt_scaler.transform(hol_num[col].values.reshape(-1, 1))[:, 0] - 0.5 if len(hol) > 0 else None
+        # replace NAs with 0.0
+        syn_num[col] = np.nan_to_num(syn_num[col], nan=0.0)
+        trn_num[col] = np.nan_to_num(trn_num[col], nan=0.0)
+        hol_num[col] = np.nan_to_num(hol_num[col], nan=0.0)
+        # add extra columns for NAs
+        if trn[col].isna().any() or hol[col].isna().any():
+            syn_num[col + " - N/A"] = syn[col].isna().astype(float)
+            trn_num[col + " - N/A"] = trn[col].isna().astype(float)
+            hol_num[col + " - N/A"] = hol[col].isna().astype(float)
+    syn_num = pd.DataFrame(syn_num, index=syn.index)
+    trn_num = pd.DataFrame(trn_num, index=trn.index)
+    hol_num = pd.DataFrame(hol_num, index=hol.index) if len(hol) > 0 else None
+    return syn_num, trn_num, hol_num
+
+
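For intuition, here is a minimal sketch (made-up data, not part of the commit) of the quantile trick used by encode_numerics: rank-transform a column to uniform [0, 1] via its empirical CDF, then shift down by 0.5, so every numeric column lands on the same [-0.5, 0.5] scale regardless of units or outliers.

import numpy as np
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

trn = pd.Series([1.0, 2.0, 3.0, 4.0, 100.0])  # heavy-tailed training column
qt = QuantileTransformer(output_distribution="uniform", random_state=42, n_quantiles=5)
qt.fit(trn.values.reshape(-1, 1))
encoded = qt.transform(trn.values.reshape(-1, 1))[:, 0] - 0.5
print(encoded)  # [-0.5, -0.25, 0.0, 0.25, 0.5] -- the outlier 100 is tamed to 0.5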
+def encode_strings(
+    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
+) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
+    """
+    Encode string features by mapping them to a low-dimensional space using PCA of their embeddings.
+    """
+    trn_str, syn_str, hol_str = {}, {}, {}
+    if hol is None:
+        hol = pd.DataFrame(columns=trn.columns)
+    for col in trn.columns:
+        # prepare inputs
+        syn_col = syn[col].astype(str).fillna(NA_BIN).replace("", EMPTY_BIN)
+        trn_col = trn[col].astype(str).fillna(NA_BIN).replace("", EMPTY_BIN)
+        hol_col = hol[col].astype(str).fillna(NA_BIN).replace("", EMPTY_BIN)
+        # get unique original values
+        uvals = pd.concat([trn_col, hol_col]).value_counts().index.to_list()
+        # map out of range values to RARE_BIN
+        syn_col = syn_col.where(syn_col.isin(uvals), RARE_BIN)
+        # embed unique values into high-dimensional space
+        embedder = load_embedder()
+        embeds = embedder.encode(uvals + [RARE_BIN])
+        # project embeddings into a low-dimensional space
+        dims = 2  # potentially adapt to the number of unique values
+        pca_model = PCA(n_components=dims)
+        embeds = pca_model.fit_transform(embeds)
+        # create mapping from unique values to PCA
+        embeds = pd.DataFrame(embeds)
+        embeds.index = uvals + [RARE_BIN]
+        # map values to PCA
+        syn_str[col] = embeds.reindex(syn_col.values).reset_index(drop=True)
+        trn_str[col] = embeds.reindex(trn_col.values).reset_index(drop=True)
+        hol_str[col] = embeds.reindex(hol_col.values).reset_index(drop=True)
+        # assign column names
+        columns = [f"{col} - PCA {i + 1}" for i in range(dims)]
+        syn_str[col].columns = columns
+        trn_str[col].columns = columns
+        hol_str[col].columns = columns
+    syn_str = pd.concat(syn_str.values(), axis=1) if syn_str else pd.DataFrame()
+    syn_str.index = syn.index
+    trn_str = pd.concat(trn_str.values(), axis=1) if trn_str else pd.DataFrame()
+    trn_str.index = trn.index
+    if len(hol) > 0:
+        hol_str = pd.concat(hol_str.values(), axis=1) if hol_str else pd.DataFrame()
+        hol_str.index = hol.index
+    else:
+        hol_str = None
+    return syn_str, trn_str, hol_str
+
+
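A rough sketch of the string path in isolation. load_embedder is internal to mostlyai.qa and not shown in this diff; for illustration, any sentence-embedding model with an encode(list[str]) -> np.ndarray method would do (a sentence-transformers model is assumed here), and "_RARE_" stands in for the actual RARE_BIN constant.

import pandas as pd
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer  # assumed stand-in for load_embedder()

uvals = ["red", "green", "blue"]                    # unique training/holdout values
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedder choice
embeds = PCA(n_components=2).fit_transform(embedder.encode(uvals + ["_RARE_"]))
lookup = pd.DataFrame(embeds, index=uvals + ["_RARE_"])        # value -> 2-dim code
syn_col = pd.Series(["red", "blue", "purple"])
syn_col = syn_col.where(syn_col.isin(uvals), "_RARE_")         # unseen "purple" -> rare bucket
codes = lookup.reindex(syn_col.values).reset_index(drop=True)  # one 2-dim row per record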
+def encode_data(
+    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
+) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:
+    """
+    Encode all columns according to their data type.
+    """
+    num_dat_cols = trn.select_dtypes(include=["number", "datetime"]).columns
+    string_cols = [col for col in trn.columns if col not in num_dat_cols]
+    syn_num, trn_num, hol_num = encode_numerics(
+        syn[num_dat_cols], trn[num_dat_cols], hol[num_dat_cols] if hol is not None else None
+    )
+    syn_str, trn_str, hol_str = encode_strings(
+        syn[string_cols], trn[string_cols], hol[string_cols] if hol is not None else None
+    )
+    syn_encoded = pd.concat([syn_num, syn_str], axis=1)
+    trn_encoded = pd.concat([trn_num, trn_str], axis=1)
+    hol_encoded = pd.concat([hol_num, hol_str], axis=1) if hol is not None else None
+    return syn_encoded, trn_encoded, hol_encoded
+
+
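Assuming the functions above are importable, usage would look roughly like this: numeric and datetime columns go through the quantile path, everything else through the embedding/PCA path, and the pieces are concatenated column-wise.

import pandas as pd

trn = pd.DataFrame({"age": [25, 31, 58], "city": ["Vienna", "Graz", "Linz"]})
syn = pd.DataFrame({"age": [27, 30, 61], "city": ["Vienna", "Linz", "Linz"]})
syn_enc, trn_enc, hol_enc = encode_data(syn, trn, hol=None)
print(trn_enc.columns.tolist())  # e.g. ['age', 'city - PCA 1', 'city - PCA 2']
print(hol_enc)                   # None -- no holdout was provided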
 def calculate_dcrs_nndrs(
     data: np.ndarray | None, query: np.ndarray | None
 ) -> tuple[np.ndarray | None, np.ndarray | None]:
     """
     Calculate Distance to Closest Records (DCRs) and Nearest Neighbor Distance Ratios (NNDRs).
-
-    Args:
-        data: Embeddings of the training data.
-        query: Embeddings of the query set.
-
-    Returns:
     """
-    if data is None or query is None:
+    if data is None or query is None or data.shape[0] == 0 or query.shape[0] == 0:
         return None, None
     _LOG.info(f"calculate DCRs for {data.shape=} and {query.shape=}")
     t0 = time.time()
     data = data[data[:, 0].argsort()]  # sort data by first dimension to enforce deterministic results
+
     if platform.system() == "Linux":
         # use FAISS on Linux for best performance
         import faiss  # type: ignore

-        index = faiss.IndexFlatIP(data.shape[1])  # inner product for cosine similarity with normalized vectors
+        index = faiss.IndexFlatL2(data.shape[1])
         index.add(data)
-        similarities, _ = index.search(query, 2)
-        dcrs = np.clip(1 - similarities, 0, 1)
+        dcrs, _ = index.search(query, 2)
+        dcrs = np.sqrt(dcrs)  # FAISS returns squared distances
     else:
         # use sklearn as a fallback on non-Linux systems to avoid segfaults; these occurred when using QA as part of SDK
         from sklearn.neighbors import NearestNeighbors  # type: ignore
         from joblib import cpu_count  # type: ignore

-        index = NearestNeighbors(
-            n_neighbors=2, algorithm="auto", metric="cosine", n_jobs=min(16, max(1, cpu_count() - 1))
-        )
+        index = NearestNeighbors(n_neighbors=2, algorithm="auto", metric="l2", n_jobs=min(16, max(1, cpu_count() - 1)))
         index.fit(data)
         dcrs, _ = index.kneighbors(query)
     dcr = dcrs[:, 0]
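One detail worth flagging in the switch from IndexFlatIP to IndexFlatL2: FAISS's flat L2 index reports squared distances, which is why the new code applies np.sqrt before deriving the DCR (nearest distance) and, presumably, the NNDR (ratio of nearest to second-nearest distance, computed just after this hunk). A tiny sketch:

import faiss
import numpy as np

data = np.array([[0.0, 0.0], [3.0, 4.0]], dtype="float32")
query = np.array([[0.0, 0.0]], dtype="float32")
index = faiss.IndexFlatL2(data.shape[1])
index.add(data)
d2, _ = index.search(query, 2)  # two nearest neighbors per query row
print(d2)                       # [[ 0. 25.]] -- squared L2 distances
print(np.sqrt(d2))              # [[0. 5.]]   -- DCR = 0.0; NNDR = 0.0 / 5.0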
@@ -70,34 +184,31 @@ def calculate_dcrs_nndrs(


 def calculate_distances(
-    *, syn_embeds: np.ndarray, trn_embeds: np.ndarray, hol_embeds: np.ndarray | None
+    *, syn_encoded: np.ndarray, trn_encoded: np.ndarray, hol_encoded: np.ndarray | None
 ) -> dict[str, np.ndarray]:
     """
     Calculates distances to the closest records (DCR).
-
-    Args:
-        syn_embeds: Embeddings of synthetic data.
-        trn_embeds: Embeddings of training data.
-        hol_embeds: Embeddings of holdout data.
-
-    Returns:
-        Dictionary containing:
-        - dcr_syn_trn: DCR for synthetic to training.
-        - dcr_syn_hol: DCR for synthetic to holdout.
-        - dcr_trn_hol: DCR for training to holdout.
-        - nndr_syn_trn: NNDR for synthetic to training.
-        - nndr_syn_hol: NNDR for synthetic to holdout.
-        - nndr_trn_hol: NNDR for training to holdout.
     """
-    if hol_embeds is not None:
-        assert trn_embeds.shape == hol_embeds.shape
+    assert syn_encoded.shape == trn_encoded.shape
+    if hol_encoded is not None and hol_encoded.shape[0] > 0:
+        assert trn_encoded.shape == hol_encoded.shape
+
+    # cap dimensionality of encoded data
+    max_dims = 256
+    if trn_encoded.shape[1] > max_dims:
+        _LOG.info(f"capping dimensionality of encoded data from {trn_encoded.shape[1]} to {max_dims}")
+        pca_model = PCA(n_components=max_dims)
+        pca_model.fit(np.vstack((trn_encoded, hol_encoded)))
+        trn_encoded = pca_model.transform(trn_encoded)
+        hol_encoded = pca_model.transform(hol_encoded)
+        syn_encoded = pca_model.transform(syn_encoded)

     # calculate DCR / NNDR for synthetic to training
-    dcr_syn_trn, nndr_syn_trn = calculate_dcrs_nndrs(data=trn_embeds, query=syn_embeds)
+    dcr_syn_trn, nndr_syn_trn = calculate_dcrs_nndrs(data=trn_encoded, query=syn_encoded)
     # calculate DCR / NNDR for synthetic to holdout
-    dcr_syn_hol, nndr_syn_hol = calculate_dcrs_nndrs(data=hol_embeds, query=syn_embeds)
+    dcr_syn_hol, nndr_syn_hol = calculate_dcrs_nndrs(data=hol_encoded, query=syn_encoded)
     # calculate DCR / NNDR for holdout to training
-    dcr_trn_hol, nndr_trn_hol = calculate_dcrs_nndrs(data=trn_embeds, query=hol_embeds)
+    dcr_trn_hol, nndr_trn_hol = calculate_dcrs_nndrs(data=trn_encoded, query=hol_encoded)

     # log statistics
     def deciles(x):
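As a hedged sketch of how such outputs are commonly read (a generic privacy heuristic, not necessarily the exact check this repo performs): if the synthetic data does not memorize training records, dcr_syn_trn should not be systematically smaller than dcr_syn_hol, so the share of synthetic records closer to training than to holdout should hover around 0.5.

import numpy as np

def dcr_share(dcr_syn_trn: np.ndarray, dcr_syn_hol: np.ndarray) -> float:
    # split ties evenly so identical distance distributions yield ~0.5
    closer = np.mean(dcr_syn_trn < dcr_syn_hol)
    ties = np.mean(dcr_syn_trn == dcr_syn_hol)
    return float(closer + 0.5 * ties)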

mostlyai/qa/_sampling.py

Lines changed: 0 additions & 17 deletions
@@ -25,11 +25,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import datetime
 import logging
 import random
 import time
-import string
 import xxhash
 from typing import Any
 from pandas.core.dtypes.common import is_numeric_dtype, is_datetime64_dtype
@@ -230,7 +228,6 @@ def pull_data_for_embeddings(
     ctx_primary_key: str | None = None,
     tgt_context_key: str | None = None,
     max_sample_size: int | None = None,
-    bins: dict[str, list] | None = None,
 ) -> list[str]:
     _LOG.info("pulling data for embeddings")
     t0 = time.time()
@@ -265,20 +262,6 @@ def pull_data_for_embeddings(
         df_tgt = df_tgt.rename(columns={tgt_context_key: key})
         tgt_context_key = key

-    # bin columns; also to prevent distortion of embeddings by adding extra precision or unknown values
-    bins = bins or {}
-    df_tgt.columns = [TGT_COLUMN_PREFIX + c if c != key else c for c in df_tgt.columns]
-    df_tgt, _ = bin_data(df_tgt, bins=bins, non_categorical_label_style="lower")
-    # add some prefix to make numeric and date values unique in the embedding space
-    for col in df_tgt.columns:
-        if col in bins:
-            if isinstance(
-                bins[col][0], (int, float, np.integer, np.floating, datetime.date, datetime.datetime, np.datetime64)
-            ):
-                prefixes = string.ascii_lowercase + string.ascii_uppercase
-                prefix = prefixes[xxhash.xxh32_intdigest(col) % len(prefixes)]
-                df_tgt[col] = prefix + df_tgt[col].astype(str)
-
     # split into chunks while keeping groups together and process in parallel
     n_jobs = min(16, max(1, cpu_count() - 1))
     hash_ids = df_tgt[tgt_context_key].apply(lambda x: xxhash.xxh32_intdigest(str(x))) % n_jobs
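For context, the deleted block implemented a small disambiguation trick that the new encoding space makes unnecessary: each binned numeric/datetime column got a deterministic one-letter prefix, hashed from the column name, so that the same value in two different columns would embed to different points. A minimal sketch of that trick:

import string
import xxhash

prefixes = string.ascii_lowercase + string.ascii_uppercase
for col in ("age", "income"):  # hypothetical column names
    prefix = prefixes[xxhash.xxh32_intdigest(col) % len(prefixes)]
    print(col, "->", prefix + "42")  # same value, column-specific prefix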

mostlyai/qa/assets/html/head.html

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,13 @@
     font-size: normal;
     color: var(--muted-color);
 }
+.ref-metric {
+    color: var(--muted-color);
+    margin-top: -2px;
+    font-size: 0.8em;
+    height: 24px;
+    font-weight: normal;
+}
 </style>
 <script>{{ html_assets['bootstrap-5.3.3.bundle.min.js'] }}</script>
 <script>{{ html_assets['plotly-3.0.1.min.js'] }}</script>
