diff --git a/pyproject.toml b/pyproject.toml
index 1a0e93f..23ec357 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,8 +55,10 @@ classifiers=[
         "faiss-gpu",
     ]
     speech = [
+        "fairseq2",
         "hanziconv",
         "inflect",
+        "seamless_communication",
         "tnkeeh",
         "torchaudio",
         "num2words",
diff --git a/stopes/pipelines/asr_bleu/__init__.py b/stopes/pipelines/asr_bleu/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/stopes/pipelines/asr_bleu/asr_bleu.py b/stopes/pipelines/asr_bleu/asr_bleu.py
new file mode 100644
index 0000000..0a587ff
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/asr_bleu.py
@@ -0,0 +1,63 @@
+import asyncio
+import logging
+from pathlib import Path
+
+import hydra
+from omegaconf import OmegaConf
+
+from stopes.core import utils
+from stopes.pipelines.asr_bleu.compute_asr_bleu import compute_asr_bleu
+from stopes.pipelines.asr_bleu.configs import AsrBleuConfig
+
+logger = logging.getLogger("asr_bleu")
+
+
+class AsrBleu:
+    """Pipeline that schedules ASR-BLEU computation through a launcher."""
+
+    def __init__(self, config: AsrBleuConfig):
+        """Prepare the output dir, build the launcher and snapshot the config."""
+        self.config = config
+        self.ensure_all_dirs()
+        # Redirect the cache into the resolved output dir *before* building the
+        # launcher: instantiate() consumes the config as it is at call time, so
+        # an override applied afterwards would never reach the launcher's cache.
+        self.config.launcher.cache.caching_dir = Path(self.output_dir) / "cache"
+        self.launcher = hydra.utils.instantiate(self.config.launcher)
+        OmegaConf.save(
+            config=config,
+            f=str(self.output_dir / "asr_bleu.yaml"),
+        )
+
+        # Freeze the config so it cannot be mutated after the snapshot is saved.
+        OmegaConf.set_readonly(self.config, True)
+
+    async def run(self) -> None:
+        """Hand the configured evaluation over to ``compute_asr_bleu`` and await it."""
+        logger.info("Computing ASRBleu on selected datasets...")
+        await compute_asr_bleu(
+            output_dir=self.config.output_dir,
+            split=self.config.split,
+            model_name=self.config.model_name,
+            eval_first_pass=self.config.eval_first_pass,
+            dataset_name=self.config.dataset,
+            audio_format=self.config.audio_format,
+            datasets=self.config.datasets,
+            launcher=self.launcher,
+        )
+
+    def ensure_all_dirs(self) -> None:
+        """Resolve ``config.output_dir`` to an absolute path and pass it to ``utils.ensure_dir``."""
+        self.output_dir = Path(self.config.output_dir).resolve()
+        utils.ensure_dir(self.output_dir)
+
+
+@hydra.main(config_path="conf", config_name="asr_bleu")
+def main(config: AsrBleuConfig) -> None:
+    """Hydra entry point: build the pipeline and run it to completion."""
+    pipeline = AsrBleu(config)
+    asyncio.run(pipeline.run())
+
+
+if __name__ == "__main__":
+    main()
diff --git 
a/stopes/pipelines/asr_bleu/compute_asr_bleu.py b/stopes/pipelines/asr_bleu/compute_asr_bleu.py
new file mode 100644
index 0000000..10a1a20
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/compute_asr_bleu.py
@@ -0,0 +1,110 @@
+import logging
+import typing as tp
+from dataclasses import dataclass
+
+from m4t_scripts.evaluate.asr_bleu import ASRBleu
+from omegaconf.omegaconf import MISSING
+
+from stopes.core.launcher import Launcher
+from stopes.core.stopes_module import Requirements, StopesModule
+from stopes.pipelines.asr_bleu.configs import Dataset
+
+
+@dataclass
+class ComputeASRBleuJob:
+    """Arguments forwarded to ``ASRBleu.compute_asr_bleu`` for one job."""
+
+    lang_dir: str = MISSING
+    split: str = MISSING
+    num_data_pairs: int = MISSING
+    model_name: str = MISSING
+    eval_first_pass: bool = MISSING
+    dataset: str = MISSING
+    audio_format: str = MISSING
+
+
+@dataclass
+class ComputeASRBleuConfig:
+    """Config of ``ComputeASRBleu``: the jobs plus the ``output_dir`` given to ``ASRBleu``."""
+
+    compute_asrbleu_jobs: tp.List[ComputeASRBleuJob] = MISSING
+    output_dir: str = MISSING
+
+
+class ComputeASRBleu(StopesModule):
+    """Stopes module that calls ``ASRBleu.compute_asr_bleu`` for each configured job."""
+
+    def __init__(self, config: ComputeASRBleuConfig):
+        super().__init__(config=config, config_class=ComputeASRBleuConfig)
+        self.asrbleu = ASRBleu(config.output_dir)
+        self.logger = logging.getLogger("stopes.asr_bleu")
+
+    def array(self):
+        """Return the configured jobs; ``run`` reads one of them as ``iteration_value``."""
+        return self.config.compute_asrbleu_jobs
+
+    def requirements(self) -> Requirements:
+        """Request one node, one task, one GPU and one CPU, with a 24-hour timeout."""
+        return Requirements(
+            nodes=1,
+            tasks_per_node=1,
+            gpus_per_node=1,
+            cpus_per_task=1,
+            timeout_min=24 * 60,
+        )
+
+    def run(
+        self,
+        iteration_value: tp.Optional[tp.Any] = None,
+        iteration_index: int = 0,
+    ) -> None:
+        """Runs compute_asr_bleu for each ComputeASRBleuJob"""
+        assert iteration_value is not None, "iteration value is null"
+        self.logger.info(f"Running compute_asr_bleu on {iteration_value.lang_dir}")
+        self.asrbleu.compute_asr_bleu(
+            iteration_value.lang_dir,
+            iteration_value.split,
+            iteration_value.num_data_pairs,
+            iteration_value.model_name,
+            iteration_value.eval_first_pass,
+            iteration_value.dataset,
+            iteration_value.audio_format,
+        )
+
+
+async def compute_asr_bleu(
+    output_dir: str,
+    split: str,
+    model_name: str,
+    eval_first_pass: bool,
+    dataset_name: str,
+    audio_format: str,
+    datasets: tp.Dict[str, Dataset],
+    launcher: Launcher,
+) -> None:
+    """
+    Schedule one ASR-BLEU job per entry of ``datasets`` on ``launcher``.
+
+    Every job shares ``split``, ``model_name``, ``eval_first_pass``,
+    ``dataset_name`` and ``audio_format``; only ``lang_dir`` and
+    ``num_data_pairs`` come from the individual ``Dataset`` entries.
+    Nothing is returned: the outcome of ``launcher.schedule`` is discarded.
+    """
+    compute_asrbleu_jobs = [
+        ComputeASRBleuJob(
+            lang_dir=entry.lang_dir,
+            split=split,
+            num_data_pairs=entry.num_data_pairs,
+            model_name=model_name,
+            eval_first_pass=eval_first_pass,
+            dataset=dataset_name,
+            audio_format=audio_format,
+        )
+        for entry in datasets.values()
+    ]
+    compute_asrbleu_module = ComputeASRBleu(
+        ComputeASRBleuConfig(
+            compute_asrbleu_jobs=compute_asrbleu_jobs, output_dir=output_dir
+        )
+    )
+    await launcher.schedule(compute_asrbleu_module)
diff --git a/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml b/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml
new file mode 100644
index 0000000..2bd1cad
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/conf/asr_bleu.yaml
@@ -0,0 +1,17 @@
+defaults:
+  - launcher: local
+  - _self_
+
+output_dir: ???
+split: "test"
+model_name: "seamlessM4T_medium"
+eval_first_pass: True
+dataset: "fleurs"
+audio_format: "n_pred.wav"
+
+launcher:
+  partition: ??? # set as null if running locally
+  cache:
+    caching_dir: ${output_dir}/cache
+
+datasets: ???
diff --git a/stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml b/stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml
new file mode 100644
index 0000000..264f635
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/conf/launcher/cache/file_cache.yaml
@@ -0,0 +1,2 @@
+_target_: stopes.core.FileCache
+caching_dir: /tmp/stopes_cache
diff --git a/stopes/pipelines/asr_bleu/conf/launcher/local.yaml b/stopes/pipelines/asr_bleu/conf/launcher/local.yaml
new file mode 100644
index 0000000..d10ebdd
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/conf/launcher/local.yaml
@@ -0,0 +1,8 @@
+defaults:
+  - cache: file_cache
+
+_target_: stopes.core.Launcher
+log_folder: executor_logs
+cluster: local
+partition: null
+max_jobarray_jobs: 1000
diff --git a/stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml b/stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml
new file mode 100644
index 0000000..0e1b0a2
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/conf/launcher/submitit.yaml
@@ -0,0 +1,8 @@
+defaults:
+  - cache: file_cache
+
+_target_: stopes.core.Launcher
+log_folder: executor_logs
+cluster: slurm
+partition: null
+max_jobarray_jobs: 1000
diff --git a/stopes/pipelines/asr_bleu/configs.py b/stopes/pipelines/asr_bleu/configs.py
new file mode 100644
index 0000000..a4f85d5
--- /dev/null
+++ b/stopes/pipelines/asr_bleu/configs.py
@@ -0,0 +1,26 @@
+import typing as tp
+from dataclasses import dataclass
+
+
+@dataclass
+class Dataset:
+    """One entry of ``AsrBleuConfig.datasets``: a ``lang_dir`` and its ``num_data_pairs``."""
+
+    lang_dir: str
+    num_data_pairs: int
+
+
+@dataclass
+class AsrBleuConfig:
+    """Top-level config of the ASR-BLEU pipeline (see ``conf/asr_bleu.yaml``)."""
+
+    output_dir: str
+    split: str
+    model_name: str
+    eval_first_pass: bool
+    dataset: str
+    audio_format: str
+    # Launcher config handed to ``hydra.utils.instantiate`` (carries a ``_target_``).
+    launcher: tp.Dict[str, tp.Any]
+    # Datasets to evaluate, keyed by name.
+    datasets: tp.Dict[str, Dataset]