Skip to content

Commit 57fdbbd

Browse files
MSD-XXX: new interface for progress reporting (#25)
1 parent 8a5e8b2 commit 57fdbbd

File tree

5 files changed

+76
-49
lines changed

5 files changed

+76
-49
lines changed

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ joblib = ">=1.2.0"
2323
Jinja2 = ">=3.1.2"
2424
scikit-learn = ">=1.4.0"
2525
sentence-transformers = ">=3.1.0"
26+
rich = "^13.9.4"
2627

2728
[tool.poetry.group.dev.dependencies]
2829
ruff = "0.7.0"

src/mostlyai/qa/common.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Protocol
16+
from functools import partial
17+
from typing import Protocol, Callable
1718

1819
import pandas as pd
19-
from tqdm.auto import tqdm
20+
from rich.progress import Progress
2021

2122
from mostlyai.qa.filesystem import Statistics
2223

@@ -73,18 +74,43 @@ class PrerequisiteNotMetError(Exception):
7374

7475

7576
class ProgressCallback(Protocol):
76-
def __call__(self, current: int, total: int) -> None: ...
77-
78-
79-
def add_tqdm(on_progress: ProgressCallback | None = None, description: str = "Processing") -> ProgressCallback:
80-
pbar = tqdm(desc=description, total=100)
81-
82-
def _on_progress(current: int, total: int):
83-
if on_progress is not None:
84-
on_progress(current, total)
85-
pbar.update(current - pbar.n)
86-
87-
return _on_progress
77+
def __call__(self, total: float | None = None, completed: float | None = None, **kwargs) -> None: ...
78+
79+
80+
class ProgressCallbackWrapper:
81+
@staticmethod
82+
def _wrap_progress_callback(
83+
update_progress: ProgressCallback | None = None, **kwargs
84+
) -> tuple[ProgressCallback, Callable]:
85+
if not update_progress:
86+
rich_progress = Progress()
87+
rich_progress.start()
88+
task_id = rich_progress.add_task(**kwargs)
89+
update_progress = partial(rich_progress.update, task_id=task_id)
90+
else:
91+
rich_progress = None
92+
93+
def teardown_progress():
94+
if rich_progress:
95+
rich_progress.refresh()
96+
rich_progress.stop()
97+
98+
return update_progress, teardown_progress
99+
100+
def update(self, total: float | None = None, completed: float | None = None, **kwargs) -> None:
101+
self._update_progress(total=total, completed=completed, **kwargs)
102+
103+
def __init__(self, update_progress: ProgressCallback | None = None, **kwargs):
104+
self._update_progress, self._teardown_progress = self._wrap_progress_callback(update_progress, **kwargs)
105+
106+
def __enter__(self):
107+
self._update_progress(completed=0, total=1)
108+
return self
109+
110+
def __exit__(self, exc_type, exc_value, traceback):
111+
if exc_type is None:
112+
self._update_progress(completed=1, total=1)
113+
self._teardown_progress()
88114

89115

90116
def check_min_sample_size(size: int, min: int, type: str) -> None:

src/mostlyai/qa/report.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
ProgressCallback,
4343
PrerequisiteNotMetError,
4444
check_min_sample_size,
45-
add_tqdm,
4645
NXT_COLUMN,
4746
CTX_COLUMN_PREFIX,
4847
TGT_COLUMN_PREFIX,
4948
REPORT_CREDITS,
49+
ProgressCallbackWrapper,
5050
)
5151
from mostlyai.qa.filesystem import Statistics, TemporaryWorkspace
5252

@@ -71,7 +71,7 @@ def report(
7171
max_sample_size_accuracy: int | None = None,
7272
max_sample_size_embeddings: int | None = None,
7373
statistics_path: str | Path | None = None,
74-
on_progress: ProgressCallback | None = None,
74+
update_progress: ProgressCallback | None = None,
7575
) -> tuple[Path, Metrics | None]:
7676
"""
7777
Generate HTML report and metrics for comparing synthetic and original data samples.
@@ -93,7 +93,7 @@ def report(
9393
max_sample_size_accuracy: Max sample size for accuracy
9494
max_sample_size_embeddings: Max sample size for embeddings (similarity & distances)
9595
statistics_path: Path of where to store the statistics to be used by `report_from_statistics`
96-
on_progress: A custom progress callback
96+
update_progress: A custom progress callback
9797
Returns:
9898
1. Path to the HTML report
9999
2. Pydantic Metrics:
@@ -119,10 +119,10 @@ def report(
119119
- `dcr_share`: Share of synthetic samples that are closer to a training sample than to a holdout sample. This shall not be significantly larger than 50\%.
120120
"""
121121

122-
with TemporaryWorkspace() as workspace:
123-
on_progress = add_tqdm(on_progress, description="Creating report")
124-
on_progress(current=0, total=100)
125-
122+
with (
123+
TemporaryWorkspace() as workspace,
124+
ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
125+
):
126126
# ensure all columns are present and in the same order as training data
127127
syn_tgt_data = syn_tgt_data[trn_tgt_data.columns]
128128
if hol_tgt_data is not None:
@@ -165,7 +165,6 @@ def report(
165165
_LOG.info(err)
166166
statistics.mark_early_exit()
167167
html_report.store_early_exit_report(report_path)
168-
on_progress(current=100, total=100)
169168
return report_path, None
170169

171170
# prepare datasets for accuracy
@@ -194,7 +193,7 @@ def report(
194193
max_sample_size=max_sample_size_accuracy,
195194
setup=setup,
196195
)
197-
on_progress(current=5, total=100)
196+
progress.update(completed=5, total=100)
198197

199198
_LOG.info("prepare training data for accuracy started")
200199
trn = pull_data_for_accuracy(
@@ -205,7 +204,7 @@ def report(
205204
max_sample_size=max_sample_size_accuracy,
206205
setup=setup,
207206
)
208-
on_progress(current=10, total=100)
207+
progress.update(completed=10, total=100)
209208

210209
# coerce dtypes to match the original training data dtypes
211210
for col in trn:
@@ -222,7 +221,7 @@ def report(
222221
statistics=statistics,
223222
workspace=workspace,
224223
)
225-
on_progress(current=20, total=100)
224+
progress.update(completed=20, total=100)
226225

227226
# ensure that embeddings are all equal size for a fair 3-way comparison
228227
max_sample_size_embeddings = min(
@@ -232,7 +231,9 @@ def report(
232231
hol_sample_size or float("inf"),
233232
)
234233

235-
def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, stop: int) -> np.ndarray:
234+
def _calc_pull_embeds(
235+
df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, progress_from: int, progress_to: int
236+
) -> np.ndarray:
236237
strings = pull_data_for_embeddings(
237238
df_tgt=df_tgt,
238239
df_ctx=df_ctx,
@@ -241,24 +242,24 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
241242
max_sample_size=max_sample_size_embeddings,
242243
)
243244
# split into buckets for calculating embeddings to avoid memory issues and report continuous progress
244-
buckets = np.array_split(strings, stop - start)
245+
buckets = np.array_split(strings, progress_to - progress_from)
245246
buckets = [b for b in buckets if len(b) > 0]
246247
embeds = []
247248
for i, bucket in enumerate(buckets, 1):
248249
embeds += [calculate_embeddings(bucket.tolist())]
249-
on_progress(current=start + i, total=100)
250-
on_progress(current=stop, total=100)
250+
progress.update(completed=progress_from + i, total=100)
251+
progress.update(completed=progress_to, total=100)
251252
embeds = np.concatenate(embeds, axis=0)
252253
_LOG.info(f"calculated embeddings {embeds.shape}")
253254
return embeds
254255

255-
syn_embeds = _calc_pull_embeds(df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, start=20, stop=40)
256-
trn_embeds = _calc_pull_embeds(df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, start=40, stop=60)
256+
syn_embeds = _calc_pull_embeds(df_tgt=syn_tgt_data, df_ctx=syn_ctx_data, progress_from=20, progress_to=40)
257+
trn_embeds = _calc_pull_embeds(df_tgt=trn_tgt_data, df_ctx=trn_ctx_data, progress_from=40, progress_to=60)
257258
if hol_tgt_data is not None:
258-
hol_embeds = _calc_pull_embeds(df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, start=60, stop=80)
259+
hol_embeds = _calc_pull_embeds(df_tgt=hol_tgt_data, df_ctx=hol_ctx_data, progress_from=60, progress_to=80)
259260
else:
260261
hol_embeds = None
261-
on_progress(current=80, total=100)
262+
progress.update(completed=80, total=100)
262263

263264
_LOG.info("report similarity")
264265
sim_cosine_trn_hol, sim_cosine_trn_syn, sim_auc_trn_hol, sim_auc_trn_syn = report_similarity(
@@ -268,7 +269,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
268269
workspace=workspace,
269270
statistics=statistics,
270271
)
271-
on_progress(current=90, total=100)
272+
progress.update(completed=90, total=100)
272273

273274
_LOG.info("report distances")
274275
dcr_trn, dcr_hol = report_distances(
@@ -277,7 +278,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
277278
hol_embeds=hol_embeds,
278279
workspace=workspace,
279280
)
280-
on_progress(current=99, total=100)
281+
progress.update(completed=99, total=100)
281282

282283
metrics = calculate_metrics(
283284
acc_uni=acc_uni,
@@ -314,7 +315,7 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
314315
acc_biv=acc_biv,
315316
corr_trn=corr_trn,
316317
)
317-
on_progress(current=100, total=100)
318+
progress.update(completed=100, total=100)
318319
return report_path, metrics
319320

320321

src/mostlyai/qa/report_from_statistics.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
ProgressCallback,
2727
PrerequisiteNotMetError,
2828
check_min_sample_size,
29-
add_tqdm,
3029
check_statistics_prerequisite,
3130
determine_data_size,
3231
REPORT_CREDITS,
32+
ProgressCallbackWrapper,
3333
)
3434
from mostlyai.qa.filesystem import Statistics, TemporaryWorkspace
3535

@@ -50,12 +50,12 @@ def report_from_statistics(
5050
report_extra_info: str = "",
5151
max_sample_size_accuracy: int | None = None,
5252
max_sample_size_embeddings: int | None = None,
53-
on_progress: ProgressCallback | None = None,
53+
update_progress: ProgressCallback | None = None,
5454
) -> Path:
55-
with TemporaryWorkspace() as workspace:
56-
on_progress = add_tqdm(on_progress, description="Creating report from statistics")
57-
on_progress(current=0, total=100)
58-
55+
with (
56+
TemporaryWorkspace() as workspace,
57+
ProgressCallbackWrapper(update_progress, description="Create report 🚀") as progress,
58+
):
5959
# prepare report_path
6060
if report_path is None:
6161
report_path = Path.cwd() / "data-report.html"
@@ -73,7 +73,6 @@ def report_from_statistics(
7373
check_min_sample_size(syn_sample_size, 100, "synthetic")
7474
except PrerequisiteNotMetError:
7575
html_report.store_early_exit_report(report_path)
76-
on_progress(current=100, total=100)
7776
return report_path
7877

7978
meta = statistics.load_meta()
@@ -96,15 +95,15 @@ def report_from_statistics(
9695
max_sample_size=max_sample_size_accuracy,
9796
)
9897
_LOG.info(f"sample synthetic data finished ({syn.shape=})")
99-
on_progress(current=20, total=100)
98+
progress.update(completed=20, total=100)
10099

101100
# calculate and plot accuracy and correlations
102101
acc_uni, acc_biv, corr_trn = report_accuracy_and_correlations_from_statistics(
103102
syn=syn,
104103
statistics=statistics,
105104
workspace=workspace,
106105
)
107-
on_progress(current=30, total=100)
106+
progress.update(completed=30, total=100)
108107

109108
_LOG.info("calculate embeddings for synthetic")
110109
syn_embeds = calculate_embeddings(
@@ -123,7 +122,7 @@ def report_from_statistics(
123122
workspace=workspace,
124123
statistics=statistics,
125124
)
126-
on_progress(current=50, total=100)
125+
progress.update(completed=50, total=100)
127126

128127
meta |= {
129128
"rows_synthetic": syn.shape[0],
@@ -144,7 +143,7 @@ def report_from_statistics(
144143
acc_biv=acc_biv,
145144
corr_trn=corr_trn,
146145
)
147-
on_progress(current=100, total=100)
146+
progress.update(completed=100, total=100)
148147
return report_path
149148

150149

0 commit comments

Comments
 (0)