|
30 | 30 |
|
def calculate_distances(
    *, syn_embeds: np.ndarray, trn_embeds: np.ndarray, hol_embeds: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
    """
    Calculates distances to the closest records (DCR) using brute-force L2 nearest-neighbor search.

    Args:
        syn_embeds: Embeddings of synthetic data.
        trn_embeds: Embeddings of training data.
        hol_embeds: Embeddings of holdout data, or None if no holdout is available. If provided, it must
            have the same shape as `trn_embeds`, because the two synthetic DCR arrays are compared
            element-wise for logging.

    Returns:
        Tuple containing:
            - dcr_syn_trn: for each training record, the distance to its nearest synthetic record.
            - dcr_syn_hol: for each holdout record, the distance to its nearest synthetic record
              (None if `hol_embeds` is None).
            - dcr_trn_hol: for each holdout record, the distance to its nearest training record
              (None if `hol_embeds` is None).

    Raises:
        AssertionError: If `hol_embeds` is provided and its shape differs from `trn_embeds`.
    """
    if hol_embeds is not None:
        assert trn_embeds.shape == hol_embeds.shape

    # `cpu_count() - 1` is 0 on a single-core host, and 0 is not a valid worker count for
    # scikit-learn/joblib; clamp to at least one worker. The `or 1` guards cpu_count()
    # implementations that return None when the count cannot be determined.
    n_jobs = max(1, min((cpu_count() or 1) - 1, 16))

    def _fit_index(embeds: np.ndarray):
        # Build a 1-NN brute-force L2 index over the given embeddings.
        index = NearestNeighbors(n_neighbors=1, algorithm="brute", metric="l2", n_jobs=n_jobs)
        index.fit(embeds)
        return index

    def _closest_distances(index, queries: np.ndarray) -> np.ndarray:
        # Distance from each query row to its single nearest neighbor in the index.
        distances, _ = index.kneighbors(queries)
        return distances[:, 0]

    def _log_deciles(label: str, dcr: np.ndarray) -> None:
        # Log the 0%, 10%, ..., 100% quantiles of a DCR array, rounded to 3 decimals.
        deciles = np.round(np.quantile(dcr, np.linspace(0, 1, 11)), 3)
        _LOG.info(f"DCR deciles for {label}: {deciles}")

    # calculate DCR for synthetic to training
    index_syn = _fit_index(syn_embeds)
    _LOG.info(f"calculate DCRs for {len(syn_embeds):,} synthetic to {len(trn_embeds):,} training")
    dcr_syn_trn = _closest_distances(index_syn, trn_embeds)

    dcr_syn_hol = None
    dcr_trn_hol = None

    if hol_embeds is not None:
        # calculate DCR for synthetic to holdout (reuses the index fitted on synthetic data)
        _LOG.info(f"calculate DCRs for {len(syn_embeds):,} synthetic to {len(hol_embeds):,} holdout")
        dcr_syn_hol = _closest_distances(index_syn, hol_embeds)

        # calculate DCR for training to holdout
        _LOG.info(f"calculate DCRs for {len(trn_embeds):,} training to {len(hol_embeds):,} holdout")
        index_trn = _fit_index(trn_embeds)
        dcr_trn_hol = _closest_distances(index_trn, hol_embeds)

    _log_deciles("synthetic to training", dcr_syn_trn)
    if dcr_syn_hol is not None:
        _log_deciles("synthetic to holdout", dcr_syn_hol)
        # calculate share of dcr_syn_trn != dcr_syn_hol
        _LOG.info(f"share of dcr_syn_trn < dcr_syn_hol: {np.mean(dcr_syn_trn < dcr_syn_hol):.1%}")
        _LOG.info(f"share of dcr_syn_trn > dcr_syn_hol: {np.mean(dcr_syn_trn > dcr_syn_hol):.1%}")

    if dcr_trn_hol is not None:
        _log_deciles("training to holdout", dcr_trn_hol)

    return dcr_syn_trn, dcr_syn_hol, dcr_trn_hol
| 87 | + |
| 88 | + |
def plot_distances(
    plot_title: str, dcr_syn_trn: np.ndarray, dcr_syn_hol: np.ndarray | None, dcr_trn_hol: np.ndarray | None
) -> go.Figure:
    """
    Plots cumulative distributions of the given DCR arrays as line traces.

    Each curve is drawn from the quantiles of a DCR array evaluated at 101 evenly spaced
    probabilities in [0, 1]; the quantile values go on the x-axis and the probabilities on the
    secondary (right-hand, percent-formatted) y-axis.

    Args:
        plot_title: Title rendered in bold at the top of the figure.
        dcr_syn_trn: DCR values for synthetic vs. training data (always plotted).
        dcr_syn_hol: DCR values for synthetic vs. holdout data; the trace is skipped if None.
        dcr_trn_hol: DCR values for training vs. holdout data; the trace is skipped if None.

    Returns:
        The assembled figure with a horizontal legend below the plot area.
    """
    # cumulative probabilities shared by all curves
    probabilities = np.linspace(0, 1, 101)

    def _quantiles(dcr: np.ndarray | None) -> np.ndarray | None:
        # Quantile curve for one DCR array; None propagates so the trace can be skipped.
        return None if dcr is None else np.quantile(dcr, probabilities)

    # thin gray axis line shared by both y-axes and the x-axis
    axis_line = dict(showline=True, linewidth=1, linecolor="#999999")

    figure = go.Figure(
        layout=go.Layout(
            title=dict(text=f"<b>{plot_title}</b>", x=0.5, y=0.98),
            title_font=CHARTS_FONTS["title"],
            font=CHARTS_FONTS["base"],
            hoverlabel=dict(
                **CHARTS_FONTS["hover"],
                namelength=-1,  # show the full trace name in hover labels
            ),
            plot_bgcolor=CHARTS_COLORS["background"],
            autosize=True,
            height=500,
            margin=dict(l=20, r=20, b=20, t=40, pad=5),
            showlegend=True,
            yaxis=dict(
                showticklabels=False,
                zeroline=True,
                zerolinewidth=1,
                zerolinecolor="#999999",
                rangemode="tozero",
                **axis_line,
            ),
            yaxis2=dict(
                overlaying="y",
                side="right",
                tickformat=".0%",
                showgrid=False,
                range=[0, 1],
                **axis_line,
            ),
            xaxis=dict(
                **axis_line,
                hoverformat=".3f",
            ),
        )
    )

    # (quantile curve, legend name, line color), in drawing order:
    # training vs holdout (light gray), synthetic vs holdout (gray), synthetic vs training (green)
    curve_specs = [
        (_quantiles(dcr_trn_hol), "Training vs. Holdout Data", "#999999"),
        (_quantiles(dcr_syn_hol), "Synthetic vs. Holdout Data", "#666666"),
        (_quantiles(dcr_syn_trn), "Synthetic vs. Training Data", "#24db96"),
    ]
    for curve_x, curve_name, curve_color in curve_specs:
        if curve_x is None:
            continue
        figure.add_trace(
            go.Scatter(
                mode="lines",
                x=curve_x,
                y=probabilities,
                name=curve_name,
                line=dict(color=curve_color, width=5),
                yaxis="y2",
            )
        )

    # horizontal legend centered below the plot; entries listed in reverse of drawing order
    figure.update_layout(
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(size=10),
            traceorder="reversed",
        )
    )

    return figure
117 | 204 |
|
118 | 205 |
|
def plot_store_distances(
    dcr_syn_trn: np.ndarray,
    dcr_syn_hol: np.ndarray | None,
    dcr_trn_hol: np.ndarray | None,
    workspace: TemporaryWorkspace,
) -> None:
    """
    Builds the DCR cumulative-distribution figure and stores it as HTML via the workspace.

    Args:
        dcr_syn_trn: DCR values for synthetic vs. training data.
        dcr_syn_hol: DCR values for synthetic vs. holdout data, or None.
        dcr_trn_hol: DCR values for training vs. holdout data, or None.
        workspace: Workspace whose `store_figure_html` receives the figure under the name "distances_dcr".
    """
    title = "Cumulative Distributions of Distance to Closest Records (DCR)"
    figure = plot_distances(
        plot_title=title,
        dcr_syn_trn=dcr_syn_trn,
        dcr_syn_hol=dcr_syn_hol,
        dcr_trn_hol=dcr_trn_hol,
    )
    workspace.store_figure_html(figure, "distances_dcr")
0 commit comments