From 42b28bba9c401a2d8829c6bf16fd29509aa783ff Mon Sep 17 00:00:00 2001 From: Alexander Erben Date: Tue, 14 Oct 2025 07:56:07 -0700 Subject: [PATCH 1/5] minor update for fs2:v0.5+ compatibility --- pyproject.toml | 4 ++-- stopes/modules/preprocess/sonar_text_embedding.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f4626e6..be06108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ classifiers=[ "numba", "transformers", "openai-whisper==20230314", - "fairseq2==0.2.*", + "fairseq2>=0.5", "sonar-space==0.2.*", ] vocal_style_sim = [ @@ -87,7 +87,7 @@ classifiers=[ ] sonar_mining = [ "sonar-space==0.2.*", - "fairseq2==0.2.*", + "fairseq2>=0.5", ] dev = [ # Test diff --git a/stopes/modules/preprocess/sonar_text_embedding.py b/stopes/modules/preprocess/sonar_text_embedding.py index 8fd83c8..456f11d 100644 --- a/stopes/modules/preprocess/sonar_text_embedding.py +++ b/stopes/modules/preprocess/sonar_text_embedding.py @@ -14,7 +14,7 @@ import pandas as pd import pyarrow as pa import torch -from fairseq2.assets.error import AssetError +from fairseq2.error import OperationalError from retrying import retry from sonar.inference_pipelines.text import ( EmbeddingToTextModelPipeline, @@ -37,7 +37,7 @@ from stopes.utils.sharding.parquet_shards import ParquetOutputConfig fairse2_asset_loading_retry = retry( - retry_on_exception=lambda exception: isinstance(exception, (AssetError, IOError)), + retry_on_exception=lambda exception: isinstance(exception, (OperationalError, IOError)), stop_max_attempt_number=20, wait_random_min=1000, wait_random_max=30_000, From 245936f7f1ce4fbac0f49af19a3e1ca9654b9ae8 Mon Sep 17 00:00:00 2001 From: Alexander Erben Date: Tue, 14 Oct 2025 11:19:04 -0400 Subject: [PATCH 2/5] black & isort --- .github/workflows/lint_and_tests.yaml | 2 ++ demo/toxicity-alti-hb/ETOX/etox.py | 1 - .../analysis/00c_plot_toxicity_per_lang.py | 8 +++-- stopes/core/jobs_registry/registry.py | 1 + 
.../core/jobs_registry/submitit_slurm_job.py | 4 ++- stopes/core/tests/test_utils.py | 2 +- stopes/eval/alti/alignment/align.py | 5 ++-- .../alti/alti_metrics/nllb_alti_detector.py | 7 +++-- stopes/eval/alti/wrappers/utils.py | 3 +- .../eval/local_prosody/compare_utterances.py | 14 ++++----- stopes/eval/local_prosody/utterance.py | 2 +- .../bitext/indexing/merge_faiss_indexes.py | 16 ++++++---- .../bitext/indexing/populate_faiss_index.py | 8 +++-- stopes/modules/bitext/mining/merge_shards.py | 3 +- .../generate_multi_bleu_detok_module.py | 7 +++-- stopes/modules/nmt_bitext_eval_utils.py | 2 +- stopes/modules/preprocess/bitext_processor.py | 1 - .../preprocess/fairseq_binarizer_encoder.py | 2 +- .../preprocess/laser_sentence_encoder.py | 6 ++-- .../multiproc_fairseq_binarizer_encoder.py | 2 +- .../preprocess/sonar_text_embedding.py | 4 ++- .../preprocess/uromanize_cli_module.py | 5 +++- .../wav2vec_laser_speech_encoder.py | 3 +- stopes/modules/speech/utils.py | 8 +++-- .../video_alignement/video_segmentor.py | 5 ++-- .../modules/tests/test_split_merge_langs.py | 14 +++++---- .../bitext/global_mining_pipeline.py | 10 ++++--- .../distillation_bitext_processor.py | 2 +- .../distillation/distillation_pipeline.py | 2 +- stopes/pipelines/eval/eval_blaser.py | 24 ++++++++------- .../monolingual/monolingual_line_processor.py | 8 ++--- .../monolingual/utils/predict_script.py | 5 +++- .../monolingual/utils/sentence_split.py | 2 +- .../pipelines/prepare_data/dedup_sharding.py | 29 ++++++++++--------- stopes/pipelines/prepare_data/validate.py | 2 +- stopes/utils/cache.py | 2 +- stopes/utils/parquet_dataloader.py | 14 +++++---- stopes/utils/tts_preprocessing/cleaners.py | 4 ++- 38 files changed, 141 insertions(+), 98 deletions(-) diff --git a/.github/workflows/lint_and_tests.yaml b/.github/workflows/lint_and_tests.yaml index 5c1c5e8..6a11f02 100644 --- a/.github/workflows/lint_and_tests.yaml +++ b/.github/workflows/lint_and_tests.yaml @@ -41,6 +41,8 @@ jobs: - name: isort run: 
isort --check --diff . + - name: black version + run: black --version - name: black run: black --check --diff . - name: pytest diff --git a/demo/toxicity-alti-hb/ETOX/etox.py b/demo/toxicity-alti-hb/ETOX/etox.py index 19bb491..14ce522 100644 --- a/demo/toxicity-alti-hb/ETOX/etox.py +++ b/demo/toxicity-alti-hb/ETOX/etox.py @@ -446,7 +446,6 @@ def etox_paired_file_wrapper( oldcolumns=True, filetype=None, ): - """ file loading/writing wrapper for the paired language toxicity evaluation function. diff --git a/demo/toxicity-alti-hb/analysis/00c_plot_toxicity_per_lang.py b/demo/toxicity-alti-hb/analysis/00c_plot_toxicity_per_lang.py index 5c0efa2..c797820 100644 --- a/demo/toxicity-alti-hb/analysis/00c_plot_toxicity_per_lang.py +++ b/demo/toxicity-alti-hb/analysis/00c_plot_toxicity_per_lang.py @@ -43,9 +43,11 @@ def plot_toxicity_per_lang(): } sorted_axes = sorted(list(axis_colors.keys())) axis_display_names = { - axis: "Race and ethnicity" - if axis == "race_ethnicity" - else axis[0].upper() + axis[1:].replace("_", " ") + axis: ( + "Race and ethnicity" + if axis == "race_ethnicity" + else axis[0].upper() + axis[1:].replace("_", " ") + ) for axis in sorted_axes } sorted_axis_names = sorted(list(axis_colors.keys())) diff --git a/stopes/core/jobs_registry/registry.py b/stopes/core/jobs_registry/registry.py index c09c005..338f8c2 100644 --- a/stopes/core/jobs_registry/registry.py +++ b/stopes/core/jobs_registry/registry.py @@ -14,6 +14,7 @@ logger = logging.getLogger("stopes.jobs") + ################################################################################ # Registry Exceptions ################################################################################ diff --git a/stopes/core/jobs_registry/submitit_slurm_job.py b/stopes/core/jobs_registry/submitit_slurm_job.py index 39d18ac..63ac6bb 100644 --- a/stopes/core/jobs_registry/submitit_slurm_job.py +++ b/stopes/core/jobs_registry/submitit_slurm_job.py @@ -284,7 +284,9 @@ def 
_convert_slurm_status_into_registry_job_status( return job_status - except KeyError: # Entering this except block means slurm_status doesn't exist in submitit_state_to_registry_state_dict + except ( + KeyError + ): # Entering this except block means slurm_status doesn't exist in submitit_state_to_registry_state_dict logger.warning( f"Job with id: {job_id} has unrecognized slurm status: {slurm_status}. Please inspect and if suitable, add this status to the slurm_state_to_registry_state_map converter." ) diff --git a/stopes/core/tests/test_utils.py b/stopes/core/tests/test_utils.py index 3b41514..8d3c154 100644 --- a/stopes/core/tests/test_utils.py +++ b/stopes/core/tests/test_utils.py @@ -188,7 +188,7 @@ async def test_semaphore(): # make sure that the semaphore blocks execution ends.sort() - for (end1, end2) in zip(ends, ends[1:]): + for end1, end2 in zip(ends, ends[1:]): t_diff = end2 - end1 assert ( t_diff.total_seconds() >= sleep_time diff --git a/stopes/eval/alti/alignment/align.py b/stopes/eval/alti/alignment/align.py index cc99e0d..c1dad1d 100644 --- a/stopes/eval/alti/alignment/align.py +++ b/stopes/eval/alti/alignment/align.py @@ -6,7 +6,7 @@ # This code was adapted from the repository https://github.com/mt-upc/transformer-contributions-nmt by Javier Ferrando. -""" Various utilities for computing word attributions and word alignment quality metrics.""" +"""Various utilities for computing word attributions and word alignment quality metrics.""" import itertools import typing as tp @@ -109,7 +109,8 @@ def compute_alignment_metrics( sure: tp.List[tp.Set], possible: tp.List[tp.Set], hypothesis: tp.List[tp.Set] ) -> tp.Tuple[float, float, float]: """Compute average alignment rate, precision and recall for alignment. - Inputs are lists of alignments. All alignments are presented as sets of (tgt, src) pairs.""" + Inputs are lists of alignments. All alignments are presented as sets of (tgt, src) pairs. 
+ """ sum_a_intersect_p, sum_a_intersect_s, sum_s, sum_a = 0, 0, 0, 0 for s, p, a in itertools.zip_longest(sure, possible, hypothesis): diff --git a/stopes/eval/alti/alti_metrics/nllb_alti_detector.py b/stopes/eval/alti/alti_metrics/nllb_alti_detector.py index 333240e..662401a 100644 --- a/stopes/eval/alti/alti_metrics/nllb_alti_detector.py +++ b/stopes/eval/alti/alti_metrics/nllb_alti_detector.py @@ -104,7 +104,8 @@ def load_bilingual_model( @dataclasses.dataclass class ALTIMetricsConfig: """The config indicating how to load sentence pairs, load the model, - compute the ALTI metrics with it, and save results. - to use with the `compute_nllb_alti` function.""" + compute the ALTI metrics with it, and save results. - to use with the `compute_nllb_alti` function. + """ # the model used to compute ALTI is_multilingual: bool @@ -118,7 +119,9 @@ class ALTIMetricsConfig: Path ] # a .jsonl file with token-level contributions # format and location of the source data - input_filename: Path # the source file with sources and translations; assumed to be .tsv + input_filename: ( + Path # the source file with sources and translations; assumed to be .tsv + ) src_lang: str tgt_lang: str src_col: tp.Union[str, int] = "src" diff --git a/stopes/eval/alti/wrappers/utils.py b/stopes/eval/alti/wrappers/utils.py index 10bd7fc..ba5c5c8 100644 --- a/stopes/eval/alti/wrappers/utils.py +++ b/stopes/eval/alti/wrappers/utils.py @@ -18,7 +18,8 @@ def spearmanr(x, y): """Compute Spearman rank's correlation bertween two attribution vectors. 
- https://github.com/samiraabnar/attention_flow/blob/master/compute_corel_distilbert_sst.py""" + https://github.com/samiraabnar/attention_flow/blob/master/compute_corel_distilbert_sst.py + """ x = pd.Series(x) y = pd.Series(y) diff --git a/stopes/eval/local_prosody/compare_utterances.py b/stopes/eval/local_prosody/compare_utterances.py index 1cc518b..34d72e3 100644 --- a/stopes/eval/local_prosody/compare_utterances.py +++ b/stopes/eval/local_prosody/compare_utterances.py @@ -124,7 +124,7 @@ def align_pauses( duration_scores = [] alignment_scores = [] - for (pauses, p2a) in [(pp_src, p2a_src), (pp_tgt, p2a_tgt)]: + for pauses, p2a in [(pp_src, p2a_src), (pp_tgt, p2a_tgt)]: for prev_word_id, duration in pauses: if prev_word_id not in p2a: duration_scores.append(0.0) @@ -353,12 +353,12 @@ def aggregate_pause_alignment_statistics(df: pd.DataFrame): "mean_duration_score": df.duration_score.mean(), "mean_alignment_score": df.alignment_score.mean(), "mean_joint_score": joint_score.mean(), - "wmean_duration_score": (df.duration_score * w).sum() / w.sum() - if non_empty - else 1, - "wmean_alignment_score": (df.alignment_score * w).sum() / w.sum() - if non_empty - else 1, + "wmean_duration_score": ( + (df.duration_score * w).sum() / w.sum() if non_empty else 1 + ), + "wmean_alignment_score": ( + (df.alignment_score * w).sum() / w.sum() if non_empty else 1 + ), "wmean_joint_score": (joint_score * w).sum() / w.sum() if non_empty else 1, "total_weight": w.sum(), "n_items": df.shape[0], diff --git a/stopes/eval/local_prosody/utterance.py b/stopes/eval/local_prosody/utterance.py index caa9a6a..c380b62 100644 --- a/stopes/eval/local_prosody/utterance.py +++ b/stopes/eval/local_prosody/utterance.py @@ -124,7 +124,7 @@ def get_text_with_markup(self, min_pause_duration=0.1, min_emph_score=0.5) -> st parts = [] pause_durations = self.get_pauses_after_words(min_duration=min_pause_duration) emphasis_scores = self.emphasis_scores or [0] * len(self.words) - for (word, pause, emph_score) 
in zip( + for word, pause, emph_score in zip( self.words, pause_durations, emphasis_scores ): if emph_score > min_emph_score: diff --git a/stopes/modules/bitext/indexing/merge_faiss_indexes.py b/stopes/modules/bitext/indexing/merge_faiss_indexes.py index acb2b52..fcc3141 100644 --- a/stopes/modules/bitext/indexing/merge_faiss_indexes.py +++ b/stopes/modules/bitext/indexing/merge_faiss_indexes.py @@ -118,12 +118,16 @@ def checkpoint( return submitit.helpers.DelayedSubmission( MergeFAISSIndexesModule( config=self.config, - checkpoint_part=self.partial_merge_file - if self.config.enable_checkpointing - else None, - checkpoint_file_idx=self.checkpoint_file_idx - if self.config.enable_checkpointing - else None, + checkpoint_part=( + self.partial_merge_file + if self.config.enable_checkpointing + else None + ), + checkpoint_file_idx=( + self.checkpoint_file_idx + if self.config.enable_checkpointing + else None + ), ), *args, **kwargs, diff --git a/stopes/modules/bitext/indexing/populate_faiss_index.py b/stopes/modules/bitext/indexing/populate_faiss_index.py index c7fcdcf..12171a9 100644 --- a/stopes/modules/bitext/indexing/populate_faiss_index.py +++ b/stopes/modules/bitext/indexing/populate_faiss_index.py @@ -202,9 +202,11 @@ def checkpoint( return submitit.helpers.DelayedSubmission( PopulateFAISSIndexModule( config=self.config, - checkpoint_summary=self.checkpoint_summary - if self.config.enable_checkpointing - else None, + checkpoint_summary=( + self.checkpoint_summary + if self.config.enable_checkpointing + else None + ), ), *args, **kwargs, diff --git a/stopes/modules/bitext/mining/merge_shards.py b/stopes/modules/bitext/mining/merge_shards.py index a29bcb0..71d8f9c 100644 --- a/stopes/modules/bitext/mining/merge_shards.py +++ b/stopes/modules/bitext/mining/merge_shards.py @@ -33,7 +33,8 @@ class MergeShardsConfig: class ShardForMerge: """Represent an input shard opened for merge. 
- Both input and output shards are in decreasing order of match scores, and this object helps manage that.""" + Both input and output shards are in decreasing order of match scores, and this object helps manage that. + """ def __init__(self, text_path: Path, meta_path: tp.Optional[Path]): self.text_path = text_path diff --git a/stopes/modules/evaluation/generate_multi_bleu_detok_module.py b/stopes/modules/evaluation/generate_multi_bleu_detok_module.py index 6bc75d9..8fae6bd 100644 --- a/stopes/modules/evaluation/generate_multi_bleu_detok_module.py +++ b/stopes/modules/evaluation/generate_multi_bleu_detok_module.py @@ -234,9 +234,10 @@ def process_output_ref_hyp_file( return_file = Path(f"{fairseq_generate_output_file}.{file_type}") desired_line_prefix = "T" if file_type == "ref" else "H" desired_col_number = 1 if file_type == "ref" else 2 - with open(fairseq_generate_output_file, "r", encoding="utf-8") as read_file, open( - return_file, "w", encoding="utf-8" - ) as write_file: + with ( + open(fairseq_generate_output_file, "r", encoding="utf-8") as read_file, + open(return_file, "w", encoding="utf-8") as write_file, + ): for line in read_file: if line.startswith(desired_line_prefix): line = line.rstrip("\n") diff --git a/stopes/modules/nmt_bitext_eval_utils.py b/stopes/modules/nmt_bitext_eval_utils.py index be30f41..ea0d312 100644 --- a/stopes/modules/nmt_bitext_eval_utils.py +++ b/stopes/modules/nmt_bitext_eval_utils.py @@ -205,7 +205,7 @@ def concat_public_bitext(self, output: Path, tgt: bool) -> int: """ lines = 0 with utils.open(output, "wb") as o: - for (src_file, tgt_file) in self.public_bitext: + for src_file, tgt_file in self.public_bitext: with utils.open(tgt_file if tgt else src_file, "rb") as f: for line in f: lines += 1 diff --git a/stopes/modules/preprocess/bitext_processor.py b/stopes/modules/preprocess/bitext_processor.py index 767dcfa..0d853c7 100644 --- a/stopes/modules/preprocess/bitext_processor.py +++ 
b/stopes/modules/preprocess/bitext_processor.py @@ -49,7 +49,6 @@ def __init__( def process_lines( self, dataset_reader: tp.Generator[DatasetLine, None, None] ) -> None: - """ process a batch of lines from two files and writes them to two output_files the way you want. The input are two iterators of lines with their line number in the input file diff --git a/stopes/modules/preprocess/fairseq_binarizer_encoder.py b/stopes/modules/preprocess/fairseq_binarizer_encoder.py index e006d76..e69acdd 100644 --- a/stopes/modules/preprocess/fairseq_binarizer_encoder.py +++ b/stopes/modules/preprocess/fairseq_binarizer_encoder.py @@ -87,7 +87,7 @@ def __enter__(self): def process_lines(self, lines_with_number: tp.Iterator[tp.Tuple[int, str]]) -> None: summary = BinarizeSummary() - for (_, s) in lines_with_number: + for _, s in lines_with_number: self.dataset_builder.add_item(self.binarizer.binarize_line(s, summary)) self.summary.merge(summary) log.info(self.summary) diff --git a/stopes/modules/preprocess/laser_sentence_encoder.py b/stopes/modules/preprocess/laser_sentence_encoder.py index 5fba626..42beb13 100644 --- a/stopes/modules/preprocess/laser_sentence_encoder.py +++ b/stopes/modules/preprocess/laser_sentence_encoder.py @@ -306,9 +306,9 @@ def combine_bidir(outs): return { "sentemb": sentemb, "encoder_out": (x, final_hiddens, final_cells), - "encoder_padding_mask": encoder_padding_mask - if encoder_padding_mask.any() - else None, + "encoder_padding_mask": ( + encoder_padding_mask if encoder_padding_mask.any() else None + ), } diff --git a/stopes/modules/preprocess/multiproc_fairseq_binarizer_encoder.py b/stopes/modules/preprocess/multiproc_fairseq_binarizer_encoder.py index bd4d623..f4bc3eb 100644 --- a/stopes/modules/preprocess/multiproc_fairseq_binarizer_encoder.py +++ b/stopes/modules/preprocess/multiproc_fairseq_binarizer_encoder.py @@ -111,7 +111,7 @@ def __enter__(self): def process_lines(self, lines_with_number: tp.Iterator[tp.Tuple[int, str]]) -> None: summary = 
BinarizeSummary() - for (_, s) in lines_with_number: + for _, s in lines_with_number: self.dataset_builder.add_item(self.binarizer.binarize_line(s, summary)) self.summary.merge(summary) logger.info(self.summary) diff --git a/stopes/modules/preprocess/sonar_text_embedding.py b/stopes/modules/preprocess/sonar_text_embedding.py index 456f11d..3103e71 100644 --- a/stopes/modules/preprocess/sonar_text_embedding.py +++ b/stopes/modules/preprocess/sonar_text_embedding.py @@ -37,7 +37,9 @@ from stopes.utils.sharding.parquet_shards import ParquetOutputConfig fairse2_asset_loading_retry = retry( - retry_on_exception=lambda exception: isinstance(exception, (OperationalError, IOError)), + retry_on_exception=lambda exception: isinstance( + exception, (OperationalError, IOError) + ), stop_max_attempt_number=20, wait_random_min=1000, wait_random_max=30_000, diff --git a/stopes/modules/preprocess/uromanize_cli_module.py b/stopes/modules/preprocess/uromanize_cli_module.py index 5c4ff6e..ba77d06 100644 --- a/stopes/modules/preprocess/uromanize_cli_module.py +++ b/stopes/modules/preprocess/uromanize_cli_module.py @@ -73,7 +73,10 @@ def run_uroman_cli_standalone(input_file: Path, output_file: Path, lang: str) -> def uromanize(text: tp.List[str]) -> tp.List[str]: if text is None or len(text) == 0: return [] - with tempfile.NamedTemporaryFile() as input_file, tempfile.NamedTemporaryFile() as output_file: + with ( + tempfile.NamedTemporaryFile() as input_file, + tempfile.NamedTemporaryFile() as output_file, + ): with open(input_file.name, "w") as f: for sentence in text: f.write(f"{sentence}\n") diff --git a/stopes/modules/preprocess/wav2vec_laser_speech_encoder.py b/stopes/modules/preprocess/wav2vec_laser_speech_encoder.py index b5e8e90..915314c 100644 --- a/stopes/modules/preprocess/wav2vec_laser_speech_encoder.py +++ b/stopes/modules/preprocess/wav2vec_laser_speech_encoder.py @@ -54,7 +54,8 @@ def __init__( @property def fbank_features(self) -> int: """Number of fbank features to feed 
to the encoder (0 if instead of fbank it expects a raw waveform). - This parameter is defined based on the architecture of the underlying encoder.""" + This parameter is defined based on the architecture of the underlying encoder. + """ return self.encoder_cfg.task.get("fbank_features", 0) def _encode_batch(self, source, padding_mask): diff --git a/stopes/modules/speech/utils.py b/stopes/modules/speech/utils.py index ff5d452..52b5a15 100644 --- a/stopes/modules/speech/utils.py +++ b/stopes/modules/speech/utils.py @@ -261,9 +261,11 @@ def auto_parse_line(line: str, sampling_factor: tp.Optional[int] = None) -> Line columns = line.split("\t") return LineResult( columns=[ - auto_parse(column, sampling_factor) - if i == 0 - else parse_audio_or_text(column, sampling_factor) + ( + auto_parse(column, sampling_factor) + if i == 0 + else parse_audio_or_text(column, sampling_factor) + ) for (i, column) in enumerate(columns) ] ) diff --git a/stopes/modules/speech/video_alignement/video_segmentor.py b/stopes/modules/speech/video_alignement/video_segmentor.py index f25d668..04e8dae 100644 --- a/stopes/modules/speech/video_alignement/video_segmentor.py +++ b/stopes/modules/speech/video_alignement/video_segmentor.py @@ -115,7 +115,6 @@ class WhisperSegmentorConfig: class WhisperSegmentorModule(StopesModule): - """Extract utterances from an audio and embed them This module is multi-lingual : different languages can be processed with the same pipeline @@ -382,7 +381,9 @@ def compute_whisper_segmentation( logger.info( f"Starting Whisper segmentation on wav of length {round(len(wav) / self.SR / 60, 3)} minuntes in lang = {lang}" ) - with tempfile.TemporaryDirectory() as data_gym_cache: # attempt to fixe loading issue + with ( + tempfile.TemporaryDirectory() as data_gym_cache + ): # attempt to fixe loading issue os.environ["DATA_GYM_CACHE_DIR"] = str(data_gym_cache) whisper_model = self._load_whisper(self.config.whisper_model).cuda() wav = wav.cpu() diff --git 
a/stopes/modules/tests/test_split_merge_langs.py b/stopes/modules/tests/test_split_merge_langs.py index 19f47cb..60ccac8 100644 --- a/stopes/modules/tests/test_split_merge_langs.py +++ b/stopes/modules/tests/test_split_merge_langs.py @@ -42,9 +42,10 @@ async def test_split_with_meta(tmp_path: Path): ) input_shards.append(text_name) input_metas.append(meta_name) - with gzip.open(text_name, mode="wt") as f_text, gzip.open( - meta_name, mode="wt" - ) as f_meta: + with ( + gzip.open(text_name, mode="wt") as f_text, + gzip.open(meta_name, mode="wt") as f_meta, + ): for line_id in range(shard_size): print("text", i, line_id, file=f_text) print("meta", i, line_id, file=f_meta) @@ -164,9 +165,10 @@ async def test_merge_bitext(tmp_path: Path): tmp_path / f"bimeta_{j}.tsv.gz", ) inputs.append((text_name, meta_name)) - with utils.open(text_name, mode="wt") as f_text, utils.open( - meta_name, mode="wt" - ) as f_meta: + with ( + utils.open(text_name, mode="wt") as f_text, + utils.open(meta_name, mode="wt") as f_meta, + ): unique_texts = [ (f"unique_text_en_{j}_{i}", f"unique_text_fr_{j}_{i}") for i in range(50) diff --git a/stopes/pipelines/bitext/global_mining_pipeline.py b/stopes/pipelines/bitext/global_mining_pipeline.py index dea9ce1..ddc1017 100644 --- a/stopes/pipelines/bitext/global_mining_pipeline.py +++ b/stopes/pipelines/bitext/global_mining_pipeline.py @@ -75,7 +75,9 @@ class Lang: # This representation of a language is used within the pipeline. # It is inherited from the config, but may change if the language is split. lang_name: str # original language name (e.g. eng) - split_name: str # language split name (can be different for big languages, e.g. eng001) + split_name: ( + str # language split name (can be different for big languages, e.g. 
eng001) + ) data_shards: tp.List[str] meta_shards: tp.Optional[tp.List[str]] shard_sizes: tp.List[int] @@ -604,9 +606,9 @@ async def _split_lang( lang_name=lng.lang_name, split_name=f"{lng.lang_name}_{i:03d}", data_shards=[str(text) for text in texts], - meta_shards=[str(meta) for meta in metas] - if metas is not None - else None, + meta_shards=( + [str(meta) for meta in metas] if metas is not None else None + ), shard_sizes=sizes, ) for i, (texts, sizes, metas) in enumerate( diff --git a/stopes/pipelines/distillation/distillation_bitext_processor.py b/stopes/pipelines/distillation/distillation_bitext_processor.py index 8318be1..ae91485 100644 --- a/stopes/pipelines/distillation/distillation_bitext_processor.py +++ b/stopes/pipelines/distillation/distillation_bitext_processor.py @@ -259,7 +259,7 @@ def process_lines( and write them to the output file """ # split sentences - for (line_id, dataset_line) in enumerate(dataset_reader): + for line_id, dataset_line in enumerate(dataset_reader): (real_tgt_line, _tgt_metadata) = extract_distillation_metadata( dataset_line.tgt ) diff --git a/stopes/pipelines/distillation/distillation_pipeline.py b/stopes/pipelines/distillation/distillation_pipeline.py index 1fd1289..abc7b8c 100644 --- a/stopes/pipelines/distillation/distillation_pipeline.py +++ b/stopes/pipelines/distillation/distillation_pipeline.py @@ -95,7 +95,7 @@ async def bitext_clean_helper( ) with utils.clone_config(config) as bitext_config: bitext_config.custom_name = f"bitext_clean_{lang_pair}.{tgt_lang}" - bitext_config.shards = [([str(file) for file in file_pair])] + bitext_config.shards = [[str(file) for file in file_pair]] bitext_config.requirements = Requirements(**config.requirements) bitext_config.bitext_processor._target_ = ( f"{BitextSplitNormalizeFilterLID.__module__}.BitextSplitNormalizeFilterLID" diff --git a/stopes/pipelines/eval/eval_blaser.py b/stopes/pipelines/eval/eval_blaser.py index d1c33ad..22c7c2c 100644 --- 
a/stopes/pipelines/eval/eval_blaser.py +++ b/stopes/pipelines/eval/eval_blaser.py @@ -102,19 +102,21 @@ async def eval_blaser( ) ) ), - launcher.schedule( - ComputeEmbedding( - ComputeEmbeddingConfig( - checkpoint_file=ref_enc, - manifest_file=config.ref_manifest, - out_file=emb_out_dir / f"reference-{config.ref_lang}-emb.npy", - checkpoint_dir=config.checkpoint_dir, - max_tokens=config.max_tokens, + ( + launcher.schedule( + ComputeEmbedding( + ComputeEmbeddingConfig( + checkpoint_file=ref_enc, + manifest_file=config.ref_manifest, + out_file=emb_out_dir / f"reference-{config.ref_lang}-emb.npy", + checkpoint_dir=config.checkpoint_dir, + max_tokens=config.max_tokens, + ) ) ) - ) - if config.ref_manifest - else None, + if config.ref_manifest + else None + ), ) # 2. call blaser diff --git a/stopes/pipelines/monolingual/monolingual_line_processor.py b/stopes/pipelines/monolingual/monolingual_line_processor.py index 70bb5f9..0ba314f 100644 --- a/stopes/pipelines/monolingual/monolingual_line_processor.py +++ b/stopes/pipelines/monolingual/monolingual_line_processor.py @@ -242,9 +242,9 @@ def __init__( thresholds_file = getattr(lid_config, "thresholds_file", None) self.lid_predictor = get_lid_predictor( model_file=Path(lid_config.model_file), - thresholds_file=Path(thresholds_file) - if thresholds_file is not None - else None, + thresholds_file=( + Path(thresholds_file) if thresholds_file is not None else None + ), label_unk=lid_config.label_unk, ) else: @@ -406,7 +406,7 @@ def process_lines(self, lines_with_number: tp.Iterator[tp.Tuple[int, str]]) -> N and write them to the output file """ # split sentences - for (line_id, line) in lines_with_number: + for line_id, line in lines_with_number: (real_line, _metadata) = extract_metadata(line, self.corpus) # we throw away metadata, use corpus+offset+linenumber to rebuild it self.result_summary.paragraphs += 1 diff --git a/stopes/pipelines/monolingual/utils/predict_script.py 
b/stopes/pipelines/monolingual/utils/predict_script.py index 94680cf..d468775 100644 --- a/stopes/pipelines/monolingual/utils/predict_script.py +++ b/stopes/pipelines/monolingual/utils/predict_script.py @@ -202,7 +202,10 @@ def test_predict_script(): assert predictor_fn( "자미로콰이 Jamiroquai는 영국의 애시드 재즈 밴드이다 자미로콰이는 년대 초반 런던에서 활발하게 일어난 애시드 재즈" ) == ("Hang", 0.8148148148148148) - assert predictor_fn("이어지는기사 에서그점 에관해알려줄것 입니다") == ("Hang", 1.0) + assert predictor_fn("이어지는기사 에서그점 에관해알려줄것 입니다") == ( + "Hang", + 1.0, + ) # not sure about this behaviour assert predictor_fn("এ 1234 b") == (None, 0) diff --git a/stopes/pipelines/monolingual/utils/sentence_split.py b/stopes/pipelines/monolingual/utils/sentence_split.py index ebea88c..e66337b 100644 --- a/stopes/pipelines/monolingual/utils/sentence_split.py +++ b/stopes/pipelines/monolingual/utils/sentence_split.py @@ -106,7 +106,7 @@ def map_lang(lang: str, equivalence_file: Path) -> str: def make_splitter_caseless( - base_splitter: tp.Callable[[str], tp.Iterable[str]] + base_splitter: tp.Callable[[str], tp.Iterable[str]], ) -> tp.Callable[[str], tp.Iterable[str]]: """ Try splitting an uppercase version of the texts. 
diff --git a/stopes/pipelines/prepare_data/dedup_sharding.py b/stopes/pipelines/prepare_data/dedup_sharding.py index e52e1c7..6e417e5 100644 --- a/stopes/pipelines/prepare_data/dedup_sharding.py +++ b/stopes/pipelines/prepare_data/dedup_sharding.py @@ -121,9 +121,10 @@ def run( if dedup_sharding_job.eval_datasets: for eval_dataset in dedup_sharding_job.eval_datasets: - with utils.open(eval_dataset.src, "rt") as s_f, utils.open( - eval_dataset.tgt, "rt" - ) as t_f: + with ( + utils.open(eval_dataset.src, "rt") as s_f, + utils.open(eval_dataset.tgt, "rt") as t_f, + ): for src_line, tgt_line in zip(s_f, t_f): self._already_seen(src_line, tgt_line, DedupType.both) @@ -164,9 +165,9 @@ def run( Dataset( src=str(src_outfile), tgt=str(tgt_outfile), - metadata=str(metadata_outfile) - if metadata_outfile - else None, + metadata=( + str(metadata_outfile) if metadata_outfile else None + ), lang_dir=lang_dir, fold=train_dataset.fold, ) @@ -175,13 +176,15 @@ def run( random.seed(0) seen_lines = 0 num_lines = 0 - with utils.open(train_dataset.src, "rt") as s_f, utils.open( - train_dataset.tgt, "rt" - ) as t_f, utils.open( - train_dataset.metadata, "rt" - ) if train_dataset.metadata else contextlib.nullcontext( - itertools.repeat(None) - ) as m_f: + with ( + utils.open(train_dataset.src, "rt") as s_f, + utils.open(train_dataset.tgt, "rt") as t_f, + ( + utils.open(train_dataset.metadata, "rt") + if train_dataset.metadata + else contextlib.nullcontext(itertools.repeat(None)) + ) as m_f, + ): for src_line, tgt_line, metadata_line in zip(s_f, t_f, m_f): shard_id = random.randint(0, num_shards - 1) if not self._already_seen( diff --git a/stopes/pipelines/prepare_data/validate.py b/stopes/pipelines/prepare_data/validate.py index 6fd6fa3..9b5c664 100644 --- a/stopes/pipelines/prepare_data/validate.py +++ b/stopes/pipelines/prepare_data/validate.py @@ -82,7 +82,7 @@ async def validate( train_tgt_counts_map = defaultdict(int) train_counts_map = defaultdict(int) - for (num_lines, dataset) 
in validation_results: + for num_lines, dataset in validation_results: if dataset.fold.startswith("train"): src, tgt = dataset.lang_dir.split("-") train_src_counts_map[src] += num_lines diff --git a/stopes/utils/cache.py b/stopes/utils/cache.py index e01b6bc..de204c9 100644 --- a/stopes/utils/cache.py +++ b/stopes/utils/cache.py @@ -105,7 +105,7 @@ def handle_batch(cache: GenerationalCache[K, V]) -> tp.Iterator[V]: for x, res in zip(batch, new_results): cache[x] = res j = 0 - for (x, maybe_res) in batch_with_results: + for x, maybe_res in batch_with_results: if maybe_res is None: maybe_res = new_results[j] j += 1 diff --git a/stopes/utils/parquet_dataloader.py b/stopes/utils/parquet_dataloader.py index 4bbff9f..63f3231 100644 --- a/stopes/utils/parquet_dataloader.py +++ b/stopes/utils/parquet_dataloader.py @@ -223,12 +223,14 @@ def epoch_iterator( ) np_rs = np.random.RandomState(seed) - with Parallel( - n_jobs=max(nb_cpu // 2, 1), backend="threading", return_as="generator" - ) as parallel_outer, Parallel( - n_jobs=nb_cpu, backend="threading", return_as="generator" - ) as parallel_inner, pyarrow_cpu( - nb_cpu + with ( + Parallel( + n_jobs=max(nb_cpu // 2, 1), backend="threading", return_as="generator" + ) as parallel_outer, + Parallel( + n_jobs=nb_cpu, backend="threading", return_as="generator" + ) as parallel_inner, + pyarrow_cpu(nb_cpu), ): if order_by_length is not None: columns = sorted( diff --git a/stopes/utils/tts_preprocessing/cleaners.py b/stopes/utils/tts_preprocessing/cleaners.py index 91fdd07..937da70 100644 --- a/stopes/utils/tts_preprocessing/cleaners.py +++ b/stopes/utils/tts_preprocessing/cleaners.py @@ -20,7 +20,9 @@ from stopes.utils.tts_preprocessing.numbers import ( SUPPORTED_LANGS as NUMEXP_SUPPORTED_LANGS, ) -from stopes.utils.tts_preprocessing.numbers import expand_numbers +from stopes.utils.tts_preprocessing.numbers import ( + expand_numbers, +) logger = logging.getLogger(__name__) From 50a9d85e9bdb077cbfb5ce46506d0287a65882a2 Mon Sep 17 
00:00:00 2001 From: Alexander Erben Date: Tue, 14 Oct 2025 12:06:02 -0400 Subject: [PATCH 3/5] Fix async error in CI --- stopes/pipelines/bitext/global_mining_pipeline.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stopes/pipelines/bitext/global_mining_pipeline.py b/stopes/pipelines/bitext/global_mining_pipeline.py index ddc1017..159bb27 100644 --- a/stopes/pipelines/bitext/global_mining_pipeline.py +++ b/stopes/pipelines/bitext/global_mining_pipeline.py @@ -419,10 +419,9 @@ async def _process_language_shard( return result def run(self) -> tp.Tuple[Path, Path]: - loop = asyncio.get_event_loop() if self.config.launcher.cluster == "debug": - loop.set_debug(True) - return loop.run_until_complete(self.arun()) + asyncio.get_event_loop().set_debug(True) + return asyncio.run(self.arun()) async def arun(self) -> tp.Tuple[Path, Path]: """Run the global mining pipeline and return the paths of the mined text and metadata files""" From 4d5b8350cda5bd2c6e50b9a84990275f6f5950da Mon Sep 17 00:00:00 2001 From: Alexander Erben Date: Tue, 14 Oct 2025 12:59:01 -0400 Subject: [PATCH 4/5] replacing mnist-text for rotten-tomatoes in test-case due to hf scripts deprecation --- stopes/pipelines/tests/test_global_mining.py | 1 - stopes/utils/sharding/hf_shards.py | 1 - stopes/utils/test_hf_shards.py | 29 ++++++++++++-------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/stopes/pipelines/tests/test_global_mining.py b/stopes/pipelines/tests/test_global_mining.py index c427283..70bf0d5 100644 --- a/stopes/pipelines/tests/test_global_mining.py +++ b/stopes/pipelines/tests/test_global_mining.py @@ -129,7 +129,6 @@ def encode_to_np( (False, True, True, "speech"), ], ) -@pytest.mark.asyncio(scope="session") def test_global_mining_pipeline( tmp_path: Path, split_langs: bool, use_meta: bool, fp16: bool, modality: str ) -> None: diff --git a/stopes/utils/sharding/hf_shards.py b/stopes/utils/sharding/hf_shards.py index 53d7cfc..02e5bc8 100644 --- 
a/stopes/utils/sharding/hf_shards.py +++ b/stopes/utils/sharding/hf_shards.py @@ -103,7 +103,6 @@ def __enter__(self): cache_dir=_cache_dir, download_mode=_download_mode, split=self.split, - trust_remote_code=self.trust_remote_code, ) if self.split is None: # _data is a DatasetDict, convert to Dataset _data = concatenate_datasets( diff --git a/stopes/utils/test_hf_shards.py b/stopes/utils/test_hf_shards.py index aaa080d..8ce5ac3 100644 --- a/stopes/utils/test_hf_shards.py +++ b/stopes/utils/test_hf_shards.py @@ -8,39 +8,44 @@ from stopes.utils.sharding.hf_shards import HFInputConfig, HFShard # TODO: Hard code this to test if there are changes in HF datasets API -first_item_id = 7 +expected_first_four = [ + 1, + 0, + 1, + 0, +] # contemmcm/rotten_tomatoes first 4 reviewState values def test_shard_iteration(): shard = HFShard( filter=None, - path_or_name="Fraser/mnist-text-small", - split="test", + path_or_name="contemmcm/rotten_tomatoes", + split="complete", index=0, num_shards=50, - trust_remote_code=True, ) with shard: item = next(iter(shard)) assert isinstance(item, dict) - assert "label" in item - assert item["label"] == first_item_id + assert "reviewState" in item + assert item["reviewState"] == expected_first_four[0] with shard as progress: batch_iter = progress.to_batches(batch_size=4) - item = next(batch_iter) - assert item["label"][0].as_py() == first_item_id # type: ignore + batch = next(batch_iter) + # Verify first 4 items match expected pattern [1,0,1,0] + for i in range(4): + assert batch["reviewState"][i].as_py() == expected_first_four[i] # type: ignore def test_input_config(): input_config = HFInputConfig( - input_file="Fraser/mnist-text-small", - split="test", + input_file="contemmcm/rotten_tomatoes", + split="complete", num_shards=50, - trust_remote_code=True, ) shards = input_config.make_shards() first_shard = shards[0] with first_shard: item = next(iter(first_shard)) - assert item["label"] == first_item_id + assert item["reviewState"] == 
expected_first_four[0] From 30e84197e30d66daf2ffcc2c57b38e22fbf6d794 Mon Sep 17 00:00:00 2001 From: Alexander Erben Date: Tue, 14 Oct 2025 15:28:48 -0400 Subject: [PATCH 5/5] removed black version from workflow --- .github/workflows/lint_and_tests.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/lint_and_tests.yaml b/.github/workflows/lint_and_tests.yaml index 6a11f02..5c1c5e8 100644 --- a/.github/workflows/lint_and_tests.yaml +++ b/.github/workflows/lint_and_tests.yaml @@ -41,8 +41,6 @@ jobs: - name: isort run: isort --check --diff . - - name: black version - run: black --version - name: black run: black --check --diff . - name: pytest