Skip to content

Commit 0730089

Browse files
authored
feat: add ims_trn_hol + dcr_trn_hol as new metrics; add dcr_trn_hol + dcr_share to HTML (#123)
1 parent 0c61a80 commit 0730089

File tree

7 files changed

+218
-92
lines changed

7 files changed

+218
-92
lines changed

examples/benchmark.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
"source": [
112112
"import matplotlib.pyplot as plt\n",
113113
"\n",
114+
"\n",
114115
"def plot_dataset(df, dataset):\n",
115116
" # Define the color mapping for each synthesizer\n",
116117
" color_mapping = {\n",

mostlyai/qa/_distances.py

Lines changed: 146 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,96 +30,186 @@
3030

3131
def calculate_distances(
3232
*, syn_embeds: np.ndarray, trn_embeds: np.ndarray, hol_embeds: np.ndarray | None
33-
) -> tuple[np.ndarray, np.ndarray | None]:
34-
"""
35-
Calculates distances to the closest records (DCR). Once for synthetic to training, and once for synthetic to
36-
holdout data.
33+
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
3734
"""
35+
Calculates distances to the closest records (DCR).
36+
37+
Args:
38+
syn_embeds: Embeddings of synthetic data.
39+
trn_embeds: Embeddings of training data.
40+
hol_embeds: Embeddings of holdout data.
3841
42+
Returns:
43+
Tuple containing:
44+
- dcr_syn_trn: DCR for synthetic to training.
45+
- dcr_syn_hol: DCR for synthetic to holdout.
46+
- dcr_trn_hol: DCR for training to holdout.
47+
"""
3948
if hol_embeds is not None:
4049
assert trn_embeds.shape == hol_embeds.shape
41-
# calculate DCR using L2 metric
42-
index = NearestNeighbors(n_neighbors=1, algorithm="brute", metric="l2", n_jobs=min(cpu_count() - 1, 16))
43-
index.fit(syn_embeds)
50+
# calculate DCR for synthetic to training
51+
index_syn = NearestNeighbors(n_neighbors=1, algorithm="brute", metric="l2", n_jobs=min(cpu_count() - 1, 16))
52+
index_syn.fit(syn_embeds)
4453
_LOG.info(f"calculate DCRs for {len(syn_embeds):,} synthetic to {len(trn_embeds):,} training")
45-
dcrs_trn, _ = index.kneighbors(trn_embeds)
46-
dcr_trn = dcrs_trn[:, 0]
54+
dcrs_syn_trn, _ = index_syn.kneighbors(trn_embeds)
55+
dcr_syn_trn = dcrs_syn_trn[:, 0]
56+
57+
dcr_syn_hol = None
58+
dcr_trn_hol = None
59+
4760
if hol_embeds is not None:
61+
# calculate DCR for synthetic to holdout
4862
_LOG.info(f"calculate DCRs for {len(syn_embeds):,} synthetic to {len(hol_embeds):,} holdout")
49-
dcrs_hol, _ = index.kneighbors(hol_embeds)
50-
dcr_hol = dcrs_hol[:, 0]
51-
else:
52-
dcr_hol = None
53-
dcr_trn_deciles = np.round(np.quantile(dcr_trn, np.linspace(0, 1, 11)), 3)
54-
_LOG.info(f"DCR deciles for synthetic to training: {dcr_trn_deciles}")
55-
if dcr_hol is not None:
56-
dcr_hol_deciles = np.round(np.quantile(dcr_hol, np.linspace(0, 1, 11)), 3)
57-
_LOG.info(f"DCR deciles for synthetic to holdout: {dcr_hol_deciles}")
58-
# calculate share of dcr_trn != dcr_hol
59-
_LOG.info(f"share of dcr_trn < dcr_hol: {np.mean(dcr_trn < dcr_hol):.1%}")
60-
_LOG.info(f"share of dcr_trn > dcr_hol: {np.mean(dcr_trn > dcr_hol):.1%}")
61-
return dcr_trn, dcr_hol
62-
63-
64-
def plot_distances(plot_title: str, dcr_trn: np.ndarray, dcr_hol: np.ndarray | None) -> go.Figure:
63+
dcrs_syn_hol, _ = index_syn.kneighbors(hol_embeds)
64+
dcr_syn_hol = dcrs_syn_hol[:, 0]
65+
66+
# calculate DCR for training to holdout
67+
_LOG.info(f"calculate DCRs for {len(trn_embeds):,} training to {len(hol_embeds):,} holdout")
68+
index_trn = NearestNeighbors(n_neighbors=1, algorithm="brute", metric="l2", n_jobs=min(cpu_count() - 1, 16))
69+
index_trn.fit(trn_embeds)
70+
dcrs_trn_hol, _ = index_trn.kneighbors(hol_embeds)
71+
dcr_trn_hol = dcrs_trn_hol[:, 0]
72+
73+
dcr_syn_trn_deciles = np.round(np.quantile(dcr_syn_trn, np.linspace(0, 1, 11)), 3)
74+
_LOG.info(f"DCR deciles for synthetic to training: {dcr_syn_trn_deciles}")
75+
if dcr_syn_hol is not None:
76+
dcr_syn_hol_deciles = np.round(np.quantile(dcr_syn_hol, np.linspace(0, 1, 11)), 3)
77+
_LOG.info(f"DCR deciles for synthetic to holdout: {dcr_syn_hol_deciles}")
78+
# calculate share of dcr_syn_trn != dcr_syn_hol
79+
_LOG.info(f"share of dcr_syn_trn < dcr_syn_hol: {np.mean(dcr_syn_trn < dcr_syn_hol):.1%}")
80+
_LOG.info(f"share of dcr_syn_trn > dcr_syn_hol: {np.mean(dcr_syn_trn > dcr_syn_hol):.1%}")
81+
82+
if dcr_trn_hol is not None:
83+
dcr_trn_hol_deciles = np.round(np.quantile(dcr_trn_hol, np.linspace(0, 1, 11)), 3)
84+
_LOG.info(f"DCR deciles for training to holdout: {dcr_trn_hol_deciles}")
85+
86+
return dcr_syn_trn, dcr_syn_hol, dcr_trn_hol
87+
88+
89+
def plot_distances(
90+
plot_title: str, dcr_syn_trn: np.ndarray, dcr_syn_hol: np.ndarray | None, dcr_trn_hol: np.ndarray | None
91+
) -> go.Figure:
6592
# calculate quantiles
6693
y = np.linspace(0, 1, 101)
67-
x_trn = np.quantile(dcr_trn, y)
68-
if dcr_hol is not None:
69-
x_hol = np.quantile(dcr_hol, y)
94+
x_syn_trn = np.quantile(dcr_syn_trn, y)
95+
if dcr_syn_hol is not None:
96+
x_syn_hol = np.quantile(dcr_syn_hol, y)
97+
else:
98+
x_syn_hol = None
99+
100+
if dcr_trn_hol is not None:
101+
x_trn_hol = np.quantile(dcr_trn_hol, y)
70102
else:
71-
x_hol = None
103+
x_trn_hol = None
104+
72105
# prepare layout
73106
layout = go.Layout(
74107
title=dict(text=f"<b>{plot_title}</b>", x=0.5, y=0.98),
75108
title_font=CHARTS_FONTS["title"],
76109
font=CHARTS_FONTS["base"],
77-
hoverlabel=CHARTS_FONTS["hover"],
110+
hoverlabel=dict(
111+
**CHARTS_FONTS["hover"],
112+
namelength=-1, # Show full length of hover labels
113+
),
78114
plot_bgcolor=CHARTS_COLORS["background"],
79115
autosize=True,
80116
height=500,
81117
margin=dict(l=20, r=20, b=20, t=40, pad=5),
82-
showlegend=False,
83-
hovermode="x unified",
118+
showlegend=True,
84119
yaxis=dict(
85120
showticklabels=False,
86121
zeroline=True,
87122
zerolinewidth=1,
88123
zerolinecolor="#999999",
89124
rangemode="tozero",
125+
showline=True,
126+
linewidth=1,
127+
linecolor="#999999",
128+
),
129+
yaxis2=dict(
130+
overlaying="y",
131+
side="right",
132+
tickformat=".0%",
133+
showgrid=False,
134+
range=[0, 1],
135+
showline=True,
136+
linewidth=1,
137+
linecolor="#999999",
138+
),
139+
xaxis=dict(
140+
showline=True,
141+
linewidth=1,
142+
linecolor="#999999",
143+
hoverformat=".3f",
90144
),
91145
)
92-
fig = go.Figure(layout=layout).set_subplots(
93-
rows=1,
94-
cols=1,
95-
)
96-
# plot content
97-
cum_trn_scatter = go.Scatter(
98-
mode="lines",
99-
x=x_trn,
100-
y=y,
101-
name="DCR training",
102-
line=dict(color=CHARTS_COLORS["synthetic"], width=5),
103-
yhoverformat=".0%",
104-
)
105-
fig.add_trace(cum_trn_scatter, row=1, col=1)
106-
if x_hol is not None:
107-
cum_hol_scatter = go.Scatter(
146+
fig = go.Figure(layout=layout)
147+
148+
traces = []
149+
150+
# training vs holdout (light gray)
151+
if x_trn_hol is not None:
152+
traces.append(
153+
go.Scatter(
154+
mode="lines",
155+
x=x_trn_hol,
156+
y=y,
157+
name="Training vs. Holdout Data",
158+
line=dict(color="#999999", width=5),
159+
yaxis="y2",
160+
)
161+
)
162+
163+
# synthetic vs holdout (gray)
164+
if x_syn_hol is not None:
165+
traces.append(
166+
go.Scatter(
167+
mode="lines",
168+
x=x_syn_hol,
169+
y=y,
170+
name="Synthetic vs. Holdout Data",
171+
line=dict(color="#666666", width=5),
172+
yaxis="y2",
173+
)
174+
)
175+
176+
# synthetic vs training (green)
177+
traces.append(
178+
go.Scatter(
108179
mode="lines",
109-
x=x_hol,
180+
x=x_syn_trn,
110181
y=y,
111-
name="DCR holdout",
112-
line=dict(color=CHARTS_COLORS["original"], width=5),
113-
yhoverformat=".0%",
182+
name="Synthetic vs. Training Data",
183+
line=dict(color="#24db96", width=5),
184+
yaxis="y2",
185+
)
186+
)
187+
188+
for trace in traces:
189+
fig.add_trace(trace)
190+
191+
fig.update_layout(
192+
legend=dict(
193+
orientation="h",
194+
yanchor="bottom",
195+
y=-0.15,
196+
xanchor="center",
197+
x=0.5,
198+
font=dict(size=10),
199+
traceorder="reversed",
114200
)
115-
fig.add_trace(cum_hol_scatter, row=1, col=1)
201+
)
202+
116203
return fig
117204

118205

119206
def plot_store_distances(
120-
dcr_trn: np.ndarray,
121-
dcr_hol: np.ndarray | None,
207+
dcr_syn_trn: np.ndarray,
208+
dcr_syn_hol: np.ndarray | None,
209+
dcr_trn_hol: np.ndarray | None,
122210
workspace: TemporaryWorkspace,
123211
) -> None:
124-
fig = plot_distances("Cumulative Distributions of Distance to Closest Records (DCR)", dcr_trn, dcr_hol)
212+
fig = plot_distances(
213+
"Cumulative Distributions of Distance to Closest Records (DCR)", dcr_syn_trn, dcr_syn_hol, dcr_trn_hol
214+
)
125215
workspace.store_figure_html(fig, "distances_dcr")

mostlyai/qa/assets/html/report_template.html

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,11 @@ <h2 id="distances" class="anchor">Distances</h2>
393393
<table class='table' style="text-align: left">
394394
<thead>
395395
<tr>
396-
<td style="width: 33%"> </td>
397-
<td style="width: 33%">Synthetic vs. Training Data</td>
396+
<td style="width: 25%"> </td>
397+
<td style="width: 25%">Synthetic vs. Training Data</td>
398398
{% if metrics.distances.ims_holdout is not none %}
399-
<td style="width: 33%"><small class="muted-text">(Synthetic vs. Holdout Data)</small></td>
399+
<td style="width: 25%"><small style="color: #666666;">Synthetic vs. Holdout Data</small></td>
400+
<td style="width: 25%"><small style="color: #999999;">Training vs. Holdout Data</small></td>
400401
{% endif %}
401402
</tr>
402403
</thead>
@@ -405,16 +406,26 @@ <h2 id="distances" class="anchor">Distances</h2>
405406
<td>Identical Matches</td>
406407
<td>{{ "{:.1%}".format(metrics.distances.ims_training) }}</td>
407408
{% if metrics.distances.ims_holdout is not none %}
408-
<td><small class="muted-text">({{ "{:.1%}".format(metrics.distances.ims_holdout) }})</small></td>
409+
<td><small style="color: #666666;">{{ "{:.1%}".format(metrics.distances.ims_holdout) }}</small></td>
410+
<td><small style="color: #999999;">{{ "{:.1%}".format(metrics.distances.ims_trn_hol) if metrics.distances.ims_trn_hol is not none else "N/A" }}</small></td>
409411
{% endif %}
410412
</tr>
411413
<tr>
412414
<td>Average Distances</td>
413415
<td>{{ "{:.3f}".format(metrics.distances.dcr_training) }}</td>
414416
{% if metrics.distances.dcr_holdout is not none %}
415-
<td><small class="muted-text">({{ "{:.3f}".format(metrics.distances.dcr_holdout) }})</small></td>
417+
<td><small style="color: #666666;">{{ "{:.3f}".format(metrics.distances.dcr_holdout) }}</small></td>
418+
<td><small style="color: #999999;">{{ "{:.3f}".format(metrics.distances.dcr_trn_hol) if metrics.distances.dcr_trn_hol is not none else "N/A" }}</small></td>
416419
{% endif %}
417420
</tr>
421+
{% if metrics.distances.dcr_share is not none %}
422+
<tr>
423+
<td>DCR Share</td>
424+
<td>{{ "{:.1%}".format(metrics.distances.dcr_share) }}</td>
425+
<td></td>
426+
<td></td>
427+
</tr>
428+
{% endif %}
418429
</tbody>
419430
</table>
420431
<br />
@@ -432,9 +443,9 @@ <h2 id="distances" class="anchor">Distances</h2>
432443
<div class="explainer-body">
433444
Synthetic data shall be as close to the original training samples, as it is close to original holdout samples, which serve us as a reference.
434445
This can be asserted empirically by measuring distances between synthetic samples to their closest original samples, whereas training and holdout sets are sampled to be of equal size.
435-
For the visualization above, the distances of synthetic samples to the training samples are displayed in green, and the distances of synthetic samples to the holdout samples (if available) displayed in gray.
436-
A green line that is significantly left of the gray line implies that synthetic samples are closer to the training samples than to the holdout samples, indicating that the data has overfitted to the training data.
437-
A green line that overlays with the gray line validates that the trained model indeed represents the general rules, that can be found in training just as well as in holdout samples.
446+
A green line that is significantly left of the dark gray line implies that synthetic samples are closer to the training samples than to the holdout samples, indicating that the data has overfitted to the training data.
447+
A green line that overlays with the dark gray line validates that the trained model indeed represents the general rules, that can be found in training just as well as in holdout samples.
448+
The DCR share indicates the proportion of synthetic samples that are closer to a training sample than to a holdout sample, and ideally, this value should not significantly exceed 50%, as a higher value could indicate overfitting.
438449
</div>
439450
</div>
440451
</div>

mostlyai/qa/metrics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ class Distances(CustomBaseModel):
192192
"`ims_training`.",
193193
ge=0.0,
194194
)
195+
ims_trn_hol: float | None = Field(
196+
default=None,
197+
alias="imsTrnHol",
198+
description="Share of training samples that are identical to a holdout sample.",
199+
ge=0.0,
200+
)
195201
dcr_training: float | None = Field(
196202
default=None,
197203
alias="dcrTraining",
@@ -201,8 +207,13 @@ class Distances(CustomBaseModel):
201207
dcr_holdout: float | None = Field(
202208
default=None,
203209
alias="dcrHoldout",
204-
description="Average L2 nearest-neighbor distance between synthetic and holdout samples. Serves as a "
205-
"reference for `dcr_training`.",
210+
description="Average L2 nearest-neighbor distance between synthetic and holdout samples. Serves as a reference for `dcr_training`.",
211+
ge=0.0,
212+
)
213+
dcr_trn_hol: float | None = Field(
214+
default=None,
215+
alias="dcrTrnHol",
216+
description="Average L2 nearest-neighbor distance between training and holdout samples. Serves as a reference for `dcr_training`.",
206217
ge=0.0,
207218
)
208219
dcr_share: float | None = Field(

0 commit comments

Comments
 (0)