Skip to content

Commit 12c9572

Browse files
authored
feat: bin all columns for embeddings (#183)
1 parent 2f13610 commit 12c9572

File tree

7 files changed

+50
-84
lines changed

7 files changed

+50
-84
lines changed

mostlyai/qa/_accuracy.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import datetime
1516
import functools
1617
import hashlib
1718
import logging
@@ -1034,7 +1035,7 @@ def binning_data(
10341035
def bin_data(
10351036
df: pd.DataFrame,
10361037
bins: int | dict[str, list],
1037-
non_categorical_label_style: Literal["bounded", "unbounded"] = "unbounded",
1038+
non_categorical_label_style: Literal["bounded", "unbounded", "lower"] = "unbounded",
10381039
) -> tuple[pd.DataFrame, dict[str, list]]:
10391040
"""
10401041
Splits data into bins.
@@ -1048,41 +1049,32 @@ def bin_data(
10481049

10491050
# Note, that we create a new pd.DataFrame to avoid fragmentation warning messages that can occur if we try to
10501051
# replace hundreds of columns of a large dataset
1051-
cols = {}
1052-
1053-
bins_dct = {}
1054-
num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
1055-
dat_cols = [c for c in df.columns if pd.api.types.is_datetime64_any_dtype(df[c])]
1056-
cat_cols = [c for c in df.columns if c not in num_cols + dat_cols]
1052+
cols, bins_dct = {}, {}
10571053
if isinstance(bins, int):
1058-
for col in num_cols:
1059-
cols[col], bins_dct[col] = bin_numeric(df[col], bins, label_style=non_categorical_label_style)
1060-
for col in dat_cols:
1061-
cols[col], bins_dct[col] = bin_datetime(df[col], bins, label_style=non_categorical_label_style)
1062-
for col in cat_cols:
1063-
cols[col], bins_dct[col] = bin_categorical(df[col], bins)
1064-
else: # bins is a dict
1065-
for col in num_cols:
1066-
if col in bins:
1067-
cols[col], _ = bin_numeric(df[col], bins[col], label_style=non_categorical_label_style)
1054+
for col in df.columns:
1055+
if pd.api.types.is_numeric_dtype(df[col]):
1056+
cols[col], bins_dct[col] = bin_numeric(df[col], bins, label_style=non_categorical_label_style)
1057+
elif pd.api.types.is_datetime64_any_dtype(df[col]):
1058+
cols[col], bins_dct[col] = bin_datetime(df[col], bins, label_style=non_categorical_label_style)
10681059
else:
1069-
_LOG.warning(f"'{col}' is missing in bins")
1070-
for col in dat_cols:
1071-
if col in bins:
1072-
cols[col], _ = bin_datetime(df[col], bins[col], label_style=non_categorical_label_style)
1073-
else:
1074-
_LOG.warning(f"'{col}' is missing in bins")
1075-
for col in cat_cols:
1060+
cols[col], bins_dct[col] = bin_categorical(df[col], bins)
1061+
else: # bins is a dict
1062+
for col in df.columns:
10761063
if col in bins:
1077-
cols[col], _ = bin_categorical(df[col], bins[col])
1064+
if isinstance(bins[col][0], (int, float, np.integer, np.floating)):
1065+
cols[col], _ = bin_numeric(df[col], bins[col], label_style=non_categorical_label_style)
1066+
elif isinstance(bins[col][0], (datetime.date, datetime.datetime, np.datetime64)):
1067+
cols[col], _ = bin_datetime(df[col], bins[col], label_style=non_categorical_label_style)
1068+
else:
1069+
cols[col], _ = bin_categorical(df[col], bins[col])
10781070
else:
1079-
_LOG.warning(f"'{col}' is missing in bins")
1071+
cols[col] = df[col]
10801072
bins_dct = bins
10811073
return pd.DataFrame(cols), bins_dct
10821074

10831075

10841076
def bin_numeric(
1085-
col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded"] = "unbounded"
1077+
col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded", "lower"] = "unbounded"
10861078
) -> tuple[pd.Categorical, list]:
10871079
def _clip(col, bins):
10881080
if isinstance(bins, list):
@@ -1131,7 +1123,7 @@ def _adjust_breaks(breaks):
11311123

11321124

11331125
def bin_datetime(
1134-
col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded"] = "unbounded"
1126+
col: pd.Series, bins: int | list[str], label_style: Literal["bounded", "unbounded", "lower"] = "unbounded"
11351127
) -> tuple[pd.Categorical, list]:
11361128
def _clip(col, bins):
11371129
if isinstance(bins, list):
@@ -1184,7 +1176,7 @@ def bin_non_categorical(
11841176
clip_and_breaks: Callable,
11851177
create_labels: Callable,
11861178
adjust_breaks: Callable,
1187-
label_style: Literal["bounded", "unbounded"] = "unbounded",
1179+
label_style: Literal["bounded", "unbounded", "lower"] = "unbounded",
11881180
) -> tuple[pd.Categorical, list]:
11891181
col = col.fillna(np.nan).infer_objects(copy=False)
11901182

@@ -1203,7 +1195,9 @@ def bin_non_categorical(
12031195
)
12041196
labels = [str(b) for b in breaks[:-1]]
12051197

1206-
if label_style == "unbounded":
1198+
if label_style == "lower":
1199+
new_labels_map = {label: f"{label}" for label in labels}
1200+
elif label_style == "unbounded":
12071201
new_labels_map = {label: f"⪰ {label}" for label in labels}
12081202
else: # label_style == "bounded"
12091203
new_labels_map = {label: f"⪰ {label}{next_label}" for label, next_label in zip(labels, labels[1:] + ["∞"])}

mostlyai/qa/_sampling.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2626
# See the License for the specific language governing permissions and
2727
# limitations under the License.
28+
import datetime
2829
import logging
2930
import random
3031
import time
@@ -229,7 +230,7 @@ def pull_data_for_embeddings(
229230
ctx_primary_key: str | None = None,
230231
tgt_context_key: str | None = None,
231232
max_sample_size: int | None = None,
232-
tgt_num_dat_bins: dict[str, list] | None = None,
233+
bins: dict[str, list] | None = None,
233234
) -> list[str]:
234235
_LOG.info("pulling data for embeddings")
235236
t0 = time.time()
@@ -264,12 +265,19 @@ def pull_data_for_embeddings(
264265
df_tgt = df_tgt.rename(columns={tgt_context_key: key})
265266
tgt_context_key = key
266267

267-
# bin numeric and datetime columns; partly also to prevent
268-
# embedding distortion by adding extra precision to values
269-
prefixes = string.ascii_lowercase + string.ascii_uppercase
270-
tgt_num_dat_bins = tgt_num_dat_bins or {}
271-
for i, col in enumerate(tgt_num_dat_bins.keys()):
272-
df_tgt[col] = bin_num_dat(values=df_tgt[col], bins=tgt_num_dat_bins[col], prefix=prefixes[i % len(prefixes)])
268+
# bin columns; also to prevent distortion of embeddings by adding extra precision or unknown values
269+
bins = bins or {}
270+
df_tgt.columns = [TGT_COLUMN_PREFIX + c if c != key else c for c in df_tgt.columns]
271+
df_tgt, _ = bin_data(df_tgt, bins=bins, non_categorical_label_style="lower")
272+
# add some prefix to make numeric and date values unique in the embedding space
273+
for col in df_tgt.columns:
274+
if col in bins:
275+
if isinstance(
276+
bins[col][0], (int, float, np.integer, np.floating, datetime.date, datetime.datetime, np.datetime64)
277+
):
278+
prefixes = string.ascii_lowercase + string.ascii_uppercase
279+
prefix = prefixes[xxhash.xxh32_intdigest(col) % len(prefixes)]
280+
df_tgt[col] = prefix + df_tgt[col].astype(str)
273281

274282
# split into chunks while keeping groups together and process in parallel
275283
n_jobs = min(16, max(1, cpu_count() - 1))
@@ -303,15 +311,6 @@ def sequence_to_string(sequence: pd.DataFrame) -> str:
303311
return strings
304312

305313

306-
def bin_num_dat(values: pd.Series, bins: list, prefix: str) -> pd.Series:
307-
bins = sorted(set(bins))
308-
binned = pd.cut(values, bins=bins, labels=bins[:-1], include_lowest=True).astype(str)
309-
binned[values <= min(bins)] = str(bins[0])
310-
binned[values >= max(bins)] = str(bins[-1])
311-
binned[values.isna()] = "NA"
312-
return prefix + binned
313-
314-
315314
def calculate_embeddings(
316315
strings: list[str],
317316
progress: ProgressCallbackWrapper | None = None,

mostlyai/qa/reporting.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import datetime
1615
import logging
1716
import warnings
1817
from pathlib import Path
@@ -289,12 +288,6 @@ def report(
289288
embedder = load_embedder()
290289
_LOG.info("load tgt bins")
291290
bins = statistics.load_bins()
292-
tgt_num_dat_bins = {
293-
c.replace(TGT_COLUMN_PREFIX, ""): bins[c]
294-
for c in bins.keys()
295-
if c.replace(TGT_COLUMN_PREFIX, "") in trn_tgt_data.columns
296-
and isinstance(bins[c][0], (int, float, datetime.date, datetime.datetime))
297-
}
298291

299292
_LOG.info("calculate embeddings for synthetic")
300293
syn_embeds = calculate_embeddings(
@@ -304,7 +297,7 @@ def report(
304297
ctx_primary_key=ctx_primary_key,
305298
tgt_context_key=tgt_context_key,
306299
max_sample_size=max_sample_size_embeddings_final,
307-
tgt_num_dat_bins=tgt_num_dat_bins,
300+
bins=bins,
308301
),
309302
progress=progress,
310303
progress_from=25,
@@ -319,7 +312,7 @@ def report(
319312
ctx_primary_key=ctx_primary_key,
320313
tgt_context_key=tgt_context_key,
321314
max_sample_size=max_sample_size_embeddings_final,
322-
tgt_num_dat_bins=tgt_num_dat_bins,
315+
bins=bins,
323316
),
324317
progress=progress,
325318
progress_from=45,
@@ -335,7 +328,7 @@ def report(
335328
ctx_primary_key=ctx_primary_key,
336329
tgt_context_key=tgt_context_key,
337330
max_sample_size=max_sample_size_embeddings_final,
338-
tgt_num_dat_bins=tgt_num_dat_bins,
331+
bins=bins,
339332
),
340333
progress=progress,
341334
progress_from=65,

mostlyai/qa/reporting_from_statistics.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import datetime
1615
import logging
1716
from pathlib import Path
1817

@@ -28,7 +27,6 @@
2827
)
2928
from mostlyai.qa._sampling import pull_data_for_embeddings, calculate_embeddings, pull_data_for_coherence
3029
from mostlyai.qa._common import (
31-
TGT_COLUMN_PREFIX,
3230
ProgressCallback,
3331
PrerequisiteNotMetError,
3432
check_min_sample_size,
@@ -169,12 +167,6 @@ def report_from_statistics(
169167
embedder = load_embedder()
170168
_LOG.info("load bins")
171169
bins = statistics.load_bins()
172-
tgt_num_dat_bins = {
173-
c.replace(TGT_COLUMN_PREFIX, ""): bins[c]
174-
for c in bins.keys()
175-
if c.replace(TGT_COLUMN_PREFIX, "") in syn_tgt_data.columns
176-
and isinstance(bins[c][0], (int, float, datetime.date, datetime.datetime))
177-
}
178170

179171
_LOG.info("calculate embeddings for synthetic")
180172
syn_embeds = calculate_embeddings(
@@ -184,7 +176,7 @@ def report_from_statistics(
184176
ctx_primary_key=ctx_primary_key,
185177
tgt_context_key=tgt_context_key,
186178
max_sample_size=max_sample_size_embeddings,
187-
tgt_num_dat_bins=tgt_num_dat_bins,
179+
bins=bins,
188180
),
189181
progress=progress,
190182
progress_from=40,

tests/end_to_end/test_report.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,6 @@ def test_report_flat_rare(tmp_path):
176176
assert metrics.accuracy.univariate == 0.0
177177
assert metrics.distances.ims_training == metrics.distances.ims_holdout == 0.0
178178

179-
# test case where rare values are not protected, and we leak trn into synthetic
180-
syn_tgt_data = pd.DataFrame({"x": trn_tgt_data["x"].sample(100, replace=True)})
181-
_, metrics = qa.report(
182-
syn_tgt_data=syn_tgt_data,
183-
trn_tgt_data=trn_tgt_data,
184-
hol_tgt_data=hol_tgt_data,
185-
statistics_path=statistics_path,
186-
)
187-
assert metrics.distances.ims_training > metrics.distances.ims_holdout
188-
assert metrics.distances.dcr_training < metrics.distances.dcr_holdout
189-
190179

191180
def test_report_flat_early_exit(tmp_path):
192181
# test early exit for dfs with <100 rows

tests/unit/test_accuracy.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,12 +615,11 @@ def test_trim_labels():
615615

616616

617617
def test_calculate_correlations(cols):
618-
trn, hol, syn = cols
619-
trn, bins = bin_data(trn, 3)
620-
syn, _ = bin_data(syn, bins)
621-
# prefix some columns with "tgt::"
622-
columns = [f"tgt::{c}" if c != "cat" else c for idx, c in enumerate(trn.columns)]
623-
trn.columns, syn.columns = columns, columns
618+
trn, _, syn = cols
619+
trn, bins = bin_data(trn[["num", "dt"]], 3)
620+
syn, _ = bin_data(syn[["num", "dt"]], bins)
621+
trn = trn.add_prefix("tgt::")
622+
syn = syn.add_prefix("tgt::")
624623
corr_trn = calculate_correlations(trn)
625624
exp_corr_trn = pd.DataFrame(
626625
[

tests/unit/test_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_pull_data_for_embeddings_large_int(tmp_path):
3333
{"cc": list(np.random.randint(100, 200, size=1000)) + [1800218404984585216] + [pd.NA]}, dtype="Int64"
3434
)
3535
bins = {"cc": [100, 200]}
36-
pull_data_for_embeddings(df_tgt=df, tgt_num_dat_bins=bins)
36+
pull_data_for_embeddings(df_tgt=df, bins=bins)
3737

3838

3939
def test_pull_data_for_embeddings_dates(tmp_path):
@@ -48,4 +48,4 @@ def test_pull_data_for_embeddings_dates(tmp_path):
4848
"y": [datetime(2020, 2, 1), datetime(2024, 1, 1)],
4949
"z": [datetime(2020, 2, 1), datetime(2024, 1, 1)],
5050
}
51-
pull_data_for_embeddings(df_tgt=df, tgt_num_dat_bins=bins)
51+
pull_data_for_embeddings(df_tgt=df, bins=bins)

0 commit comments

Comments
 (0)