Skip to content

Commit c0ddf9f

Browse files
committed
add progress bar to sampler
1 parent d8a6110 commit c0ddf9f

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/chemnlp/data/sampler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import yaml
1414
import json
1515
from loguru import logger
16+
from tqdm import tqdm
1617

1718

1819
class TemplateSampler:
@@ -836,8 +837,11 @@ def export(self, output_dir: str, template: str) -> pd.DataFrame:
836837
)
837838
for split in self.df["split"].unique():
838839
df_split = self.df[self.df["split"] == split]
839-
samples = [self.sample(row, template) for _, row in df_split.iterrows()]
840-
840+
samples = []
841+
for _, row in tqdm(df_split.iterrows(), total=len(df_split)):
842+
sample_dict = row.to_dict()
843+
sample = self._fill_template(template, sample_dict)
844+
samples.append(sample)
841845
df_out = pd.DataFrame(samples)
842846

843847
# if self.benchmarking_templates:

0 commit comments

Comments
 (0)