4242 ProgressCallback ,
4343 PrerequisiteNotMetError ,
4444 check_min_sample_size ,
45- add_tqdm ,
4645 NXT_COLUMN ,
4746 CTX_COLUMN_PREFIX ,
4847 TGT_COLUMN_PREFIX ,
4948 REPORT_CREDITS ,
49+ ProgressCallbackWrapper ,
5050)
5151from mostlyai .qa .filesystem import Statistics , TemporaryWorkspace
5252
@@ -71,7 +71,7 @@ def report(
7171 max_sample_size_accuracy : int | None = None ,
7272 max_sample_size_embeddings : int | None = None ,
7373 statistics_path : str | Path | None = None ,
74- on_progress : ProgressCallback | None = None ,
74+ update_progress : ProgressCallback | None = None ,
7575) -> tuple [Path , Metrics | None ]:
7676 """
7777 Generate HTML report and metrics for comparing synthetic and original data samples.
@@ -93,7 +93,7 @@ def report(
9393 max_sample_size_accuracy: Max sample size for accuracy
9494 max_sample_size_embeddings: Max sample size for embeddings (similarity & distances)
9595 statistics_path: Path of where to store the statistics to be used by `report_from_statistics`
96- on_progress : A custom progress callback
96+ update_progress : A custom progress callback
9797 Returns:
9898 1. Path to the HTML report
9999 2. Pydantic Metrics:
@@ -119,10 +119,10 @@ def report(
119119 - `dcr_share`: Share of synthetic samples that are closer to a training sample than to a holdout sample. This shall not be significantly larger than 50\%.
120120 """
121121
122- with TemporaryWorkspace () as workspace :
123- on_progress = add_tqdm ( on_progress , description = "Creating report" )
124- on_progress ( current = 0 , total = 100 )
125-
122+ with (
123+ TemporaryWorkspace () as workspace ,
124+ ProgressCallbackWrapper ( update_progress , description = "Create report 🚀" ) as progress ,
125+ ):
126126 # ensure all columns are present and in the same order as training data
127127 syn_tgt_data = syn_tgt_data [trn_tgt_data .columns ]
128128 if hol_tgt_data is not None :
@@ -165,7 +165,6 @@ def report(
165165 _LOG .info (err )
166166 statistics .mark_early_exit ()
167167 html_report .store_early_exit_report (report_path )
168- on_progress (current = 100 , total = 100 )
169168 return report_path , None
170169
171170 # prepare datasets for accuracy
@@ -194,7 +193,7 @@ def report(
194193 max_sample_size = max_sample_size_accuracy ,
195194 setup = setup ,
196195 )
197- on_progress ( current = 5 , total = 100 )
196+ progress . update ( completed = 5 , total = 100 )
198197
199198 _LOG .info ("prepare training data for accuracy started" )
200199 trn = pull_data_for_accuracy (
@@ -205,7 +204,7 @@ def report(
205204 max_sample_size = max_sample_size_accuracy ,
206205 setup = setup ,
207206 )
208- on_progress ( current = 10 , total = 100 )
207+ progress . update ( completed = 10 , total = 100 )
209208
210209 # coerce dtypes to match the original training data dtypes
211210 for col in trn :
@@ -222,7 +221,7 @@ def report(
222221 statistics = statistics ,
223222 workspace = workspace ,
224223 )
225- on_progress ( current = 20 , total = 100 )
224+ progress . update ( completed = 20 , total = 100 )
226225
227226 # ensure that embeddings are all equal size for a fair 3-way comparison
228227 max_sample_size_embeddings = min (
@@ -232,7 +231,9 @@ def report(
232231 hol_sample_size or float ("inf" ),
233232 )
234233
235- def _calc_pull_embeds (df_tgt : pd .DataFrame , df_ctx : pd .DataFrame , start : int , stop : int ) -> np .ndarray :
234+ def _calc_pull_embeds (
235+ df_tgt : pd .DataFrame , df_ctx : pd .DataFrame , progress_from : int , progress_to : int
236+ ) -> np .ndarray :
236237 strings = pull_data_for_embeddings (
237238 df_tgt = df_tgt ,
238239 df_ctx = df_ctx ,
@@ -241,24 +242,24 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
241242 max_sample_size = max_sample_size_embeddings ,
242243 )
243244 # split into buckets for calculating embeddings to avoid memory issues and report continuous progress
244- buckets = np .array_split (strings , stop - start )
245+ buckets = np .array_split (strings , progress_to - progress_from )
245246 buckets = [b for b in buckets if len (b ) > 0 ]
246247 embeds = []
247248 for i , bucket in enumerate (buckets , 1 ):
248249 embeds += [calculate_embeddings (bucket .tolist ())]
249- on_progress ( current = start + i , total = 100 )
250- on_progress ( current = stop , total = 100 )
250+ progress . update ( completed = progress_from + i , total = 100 )
251+ progress . update ( completed = progress_to , total = 100 )
251252 embeds = np .concatenate (embeds , axis = 0 )
252253 _LOG .info (f"calculated embeddings { embeds .shape } " )
253254 return embeds
254255
255- syn_embeds = _calc_pull_embeds (df_tgt = syn_tgt_data , df_ctx = syn_ctx_data , start = 20 , stop = 40 )
256- trn_embeds = _calc_pull_embeds (df_tgt = trn_tgt_data , df_ctx = trn_ctx_data , start = 40 , stop = 60 )
256+ syn_embeds = _calc_pull_embeds (df_tgt = syn_tgt_data , df_ctx = syn_ctx_data , progress_from = 20 , progress_to = 40 )
257+ trn_embeds = _calc_pull_embeds (df_tgt = trn_tgt_data , df_ctx = trn_ctx_data , progress_from = 40 , progress_to = 60 )
257258 if hol_tgt_data is not None :
258- hol_embeds = _calc_pull_embeds (df_tgt = hol_tgt_data , df_ctx = hol_ctx_data , start = 60 , stop = 80 )
259+ hol_embeds = _calc_pull_embeds (df_tgt = hol_tgt_data , df_ctx = hol_ctx_data , progress_from = 60 , progress_to = 80 )
259260 else :
260261 hol_embeds = None
261- on_progress ( current = 80 , total = 100 )
262+ progress . update ( completed = 80 , total = 100 )
262263
263264 _LOG .info ("report similarity" )
264265 sim_cosine_trn_hol , sim_cosine_trn_syn , sim_auc_trn_hol , sim_auc_trn_syn = report_similarity (
@@ -268,7 +269,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
268269 workspace = workspace ,
269270 statistics = statistics ,
270271 )
271- on_progress ( current = 90 , total = 100 )
272+ progress . update ( completed = 90 , total = 100 )
272273
273274 _LOG .info ("report distances" )
274275 dcr_trn , dcr_hol = report_distances (
@@ -277,7 +278,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
277278 hol_embeds = hol_embeds ,
278279 workspace = workspace ,
279280 )
280- on_progress ( current = 99 , total = 100 )
281+ progress . update ( completed = 99 , total = 100 )
281282
282283 metrics = calculate_metrics (
283284 acc_uni = acc_uni ,
@@ -314,7 +315,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
314315 acc_biv = acc_biv ,
315316 corr_trn = corr_trn ,
316317 )
317- on_progress ( current = 100 , total = 100 )
318+ progress . update ( completed = 100 , total = 100 )
318319 return report_path , metrics
319320
320321
0 commit comments