From 250d6a33e5d0acac7293f3555415ad13ea2f9e78 Mon Sep 17 00:00:00 2001 From: Ben King Date: Tue, 27 Jan 2026 01:41:19 +0000 Subject: [PATCH 1/4] Initial implementation of multilingual inference (just for translate step so far) --- silnlp/common/translator.py | 22 ++++++++++----- silnlp/nmt/config.py | 8 +++--- silnlp/nmt/hugging_face_config.py | 45 ++++++++++++++++++++++++------- silnlp/nmt/translate.py | 8 +++--- 4 files changed, 59 insertions(+), 24 deletions(-) diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 51c32879..04ea1061 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -265,11 +265,11 @@ class Translator(AbstractContextManager["Translator"], ABC): @abstractmethod def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[str], + src_isos: List[str], + trg_isos: List[str], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, + vrefs: Optional[List[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: pass @@ -286,8 +286,10 @@ def translate_text( ) -> None: sentences = [add_tags_to_sentence(tags, sentence) for sentence in load_corpus(src_file_path)] + src_isos = [src_iso for s in sentences] + trg_isos = [trg_iso for s in sentences] sentence_translation_groups: List[SentenceTranslationGroup] = list( - self.translate(sentences, src_iso, trg_iso, produce_multiple_translations) + self.translate(sentences, src_isos, trg_isos, produce_multiple_translations) ) draft_set = DraftGroup(sentence_translation_groups) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): @@ -387,6 +389,8 @@ def translate_usfm( src_file_text = UsfmFileText(stylesheet, "utf-8-sig", book_id, src_file_path, include_all_text=True) sentences = [re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())) for s in src_file_text] + src_isos = [src_iso for s in src_file_text] + trg_isos = [trg_iso for s in src_file_text] 
scripture_refs: List[ScriptureRef] = [s.ref for s in src_file_text] vrefs: List[VerseRef] = [sr.verse_ref for sr in scripture_refs] LOGGER.info(f"File {src_file_path} parsed correctly.") @@ -408,7 +412,7 @@ def translate_usfm( empty_sents.append((i, scripture_refs.pop(i))) sentence_translation_groups: List[SentenceTranslationGroup] = list( - self.translate(sentences, src_iso, trg_iso, produce_multiple_translations, vrefs) + self.translate(sentences, src_isos, trg_isos, produce_multiple_translations, vrefs) ) num_drafts = len(sentence_translation_groups[0]) @@ -539,14 +543,18 @@ def translate_docx( sentences: List[str] = [] paras: List[int] = [] + src_isos: List[str] = [] + trg_isos: List[str] = [] for i, paragraph in enumerate(doc.paragraphs): for sentence in tokenizer.tokenize(paragraph.text): sentences.append(add_tags_to_sentence(tags, sentence)) + src_isos.append(src_iso) + trg_isos.append(trg_iso) paras.append(i) draft_set: DraftGroup = DraftGroup( - list(self.translate(sentences, src_iso, trg_iso, produce_multiple_translations)) + list(self.translate(sentences, src_isos, trg_isos, produce_multiple_translations)) ) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index d2b24538..7a829f6a 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -81,11 +81,11 @@ def translate_test_files( @abstractmethod def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[str], + src_isos: List[str], + trg_isos: List[str], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, + vrefs: Optional[List[VerseRef]] = None, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> Generator[SentenceTranslationGroup, None, None]: ... 
diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index 0ab5af1a..60708ebb 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -1260,21 +1260,23 @@ def get_num_drafts(self) -> int: def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[str], + src_isos: List[str], + trg_isos: List[str], produce_multiple_translations: bool = False, vrefs: Optional[Iterable[VerseRef]] = None, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> Generator[SentenceTranslationGroup, None, None]: - src_lang = self._config.data["lang_codes"].get(src_iso, src_iso) - trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso) - inference_model_params = InferenceModelParams(ckpt, src_lang, trg_lang) + src_langs = [self._config.data["lang_codes"].get(src_iso, src_iso) for src_iso in src_isos] + trg_langs = [self._config.data["lang_codes"].get(trg_iso, trg_iso) for trg_iso in trg_isos] + inference_model_params = InferenceModelParams(ckpt, src_langs[0], trg_langs[0]) tokenizer = self._config.get_tokenizer() if self._inference_model_params == inference_model_params and self._cached_inference_model is not None: model = self._cached_inference_model else: - model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang) + model = self._cached_inference_model = self._create_inference_model( + ckpt, tokenizer, src_langs[0], trg_langs[0] + ) self._inference_model_params = inference_model_params if model.config.max_length is not None and model.config.max_length < 512: model.config.max_length = 512 @@ -1287,8 +1289,8 @@ def translate( pipeline = SilTranslationPipeline( model=model, tokenizer=tokenizer, - src_lang=src_lang, - tgt_lang=trg_lang, + src_langs=src_langs, + tgt_langs=trg_langs, device=0, ) @@ -1940,6 +1942,21 @@ def normalize(self, line: NormalizedString) -> None: class SilTranslationPipeline(TranslationPipeline): + def 
__init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + src_langs: List[str], + tgt_langs: List[str], + device: int, + decoder_lang_code_tokens: List[int] | None = None, + ): + super().__init__(model=model, tokenizer=tokenizer, src_lang=src_langs[0], tgt_lang=tgt_langs[0], device=device) + self.tgt_langs = np.array( + [tokenizer.convert_tokens_to_ids(trg_lang) for trg_lang in tgt_langs], dtype=np.float32 + ) + self.tgt_index = 0 + def _forward(self, model_inputs, **generate_kwargs): in_b, input_length = model_inputs["input_ids"].shape @@ -1950,6 +1967,16 @@ def _forward(self, model_inputs, **generate_kwargs): generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length) generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) + generate_kwargs["decoder_input_ids"] = torch.cat( + ( + torch.ones((in_b, 1), dtype=torch.long, device=model_inputs["input_ids"].device) * 2, + torch.unsqueeze(torch.from_numpy(self.tgt_langs[self.tgt_index : self.tgt_index + in_b]), 0).to( + model_inputs["input_ids"].device + ), + ), + dim=1, + ) + self.tgt_index += in_b output = self.model.generate( **model_inputs, **generate_kwargs, diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index 2d936616..7866397f 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -27,14 +27,14 @@ def __init__(self, model: NMTModel, checkpoint: Union[CheckpointType, str, int]) def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[str], + src_isos: List[str], + trg_isos: List[str], produce_multiple_translations: bool = False, vrefs: Optional[Iterable[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: yield from self._model.translate( - sentences, src_iso, trg_iso, produce_multiple_translations, vrefs, self._checkpoint + sentences, src_isos, trg_isos, 
produce_multiple_translations, vrefs, self._checkpoint ) def __exit__( From b1a1780e9e99e816d1d9e7cff7cdd2a54cd94497 Mon Sep 17 00:00:00 2001 From: Ben King Date: Wed, 28 Jan 2026 17:51:38 +0000 Subject: [PATCH 2/4] Complete? support for translate step --- silnlp/common/translate_google.py | 13 +-- silnlp/common/translator.py | 77 +++++++------ silnlp/nmt/config.py | 8 +- silnlp/nmt/hugging_face_config.py | 186 ++++++++++++++++++------------ silnlp/nmt/test.py | 6 + silnlp/nmt/translate.py | 16 +-- 6 files changed, 174 insertions(+), 132 deletions(-) diff --git a/silnlp/common/translate_google.py b/silnlp/common/translate_google.py index 9cf3bd97..09772044 100644 --- a/silnlp/common/translate_google.py +++ b/silnlp/common/translate_google.py @@ -6,7 +6,7 @@ from machine.scripture import VerseRef, book_id_to_number from .paratext import book_file_name_digits -from .translator import SentenceTranslation, SentenceTranslationGroup, Translator +from .translator import SentenceTranslation, SentenceTranslationGroup, TranslationInputSentence, Translator from .utils import get_git_revision_hash, get_mt_exp_dir LOGGER = logging.getLogger((__package__ or "") + ".translate") @@ -18,21 +18,18 @@ def __init__(self) -> None: def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: Iterable[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: if produce_multiple_translations: LOGGER.warning("Google Translator does not support --multiple-translations") for sentence in sentences: - if len(sentence) == 0: - yield "" + if len(sentence.text) == 0: + yield [SentenceTranslation("", [], [], None)] else: results = self._translate_client.translate( - sentence, source_language=src_iso, target_language=trg_iso, format_="text" + sentence.text, source_language=sentence.src_iso, target_language=sentence.trg_iso, format_="text" ) translation: 
str = results["translatedText"] yield [SentenceTranslation(translation, [], [], None)] diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 04ea1061..19993833 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -3,11 +3,12 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import AbstractContextManager +from dataclasses import dataclass from datetime import date from itertools import groupby from math import exp from pathlib import Path -from typing import DefaultDict, Generator, Iterable, List, Optional, Tuple, cast +from typing import DefaultDict, Generator, List, Optional, Tuple import docx import nltk @@ -40,6 +41,15 @@ CONFIDENCE_SCORES_SUFFIX = ".confidences.tsv" +@dataclass +class TranslationInputSentence: + text: str + src_iso: str + trg_iso: str + scripture_ref: ScriptureRef | None = None + vref: VerseRef | None = None + + # A single translation of a single sentence class SentenceTranslation: def __init__( @@ -265,11 +275,8 @@ class Translator(AbstractContextManager["Translator"], ABC): @abstractmethod def translate( self, - sentences: List[str], - src_isos: List[str], - trg_isos: List[str], + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[List[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: pass @@ -285,11 +292,12 @@ def translate_text( tags: Optional[List[str]] = None, ) -> None: - sentences = [add_tags_to_sentence(tags, sentence) for sentence in load_corpus(src_file_path)] - src_isos = [src_iso for s in sentences] - trg_isos = [trg_iso for s in sentences] + translation_inputs = [ + TranslationInputSentence(add_tags_to_sentence(tags, sentence), src_iso, trg_iso) + for sentence in load_corpus(src_file_path) + ] sentence_translation_groups: List[SentenceTranslationGroup] = list( - self.translate(sentences, src_isos, trg_isos, produce_multiple_translations) + 
self.translate(translation_inputs, produce_multiple_translations) ) draft_set = DraftGroup(sentence_translation_groups) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): @@ -388,39 +396,38 @@ def translate_usfm( src_file_text = UsfmFileText(stylesheet, "utf-8-sig", book_id, src_file_path, include_all_text=True) - sentences = [re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())) for s in src_file_text] - src_isos = [src_iso for s in src_file_text] - trg_isos = [trg_iso for s in src_file_text] - scripture_refs: List[ScriptureRef] = [s.ref for s in src_file_text] - vrefs: List[VerseRef] = [sr.verse_ref for sr in scripture_refs] + sentences = [ + TranslationInputSentence( + re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())), src_iso, trg_iso, s.ref, s.ref.verse_ref + ) + for s in src_file_text + ] LOGGER.info(f"File {src_file_path} parsed correctly.") # Filter sentences for i in reversed(range(len(sentences))): - marker = scripture_refs[i].path[-1].name if len(scripture_refs[i].path) > 0 else "" + marker = sentences[i].scripture_ref.path[-1].name if len(sentences[i].scripture_ref.path) > 0 else "" if ( - (len(chapters) > 0 and scripture_refs[i].chapter_num not in chapters) + (len(chapters) > 0 and sentences[i].scripture_ref.chapter_num not in chapters) or marker in PARAGRAPH_TYPE_EMBEDS or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT ): sentences.pop(i) - scripture_refs.pop(i) empty_sents: List[Tuple[int, ScriptureRef]] = [] for i in reversed(range(len(sentences))): - if len(sentences[i].strip()) == 0: + if len(sentences[i].text.strip()) == 0: + empty_sents.append((i, sentences[i].scripture_ref)) sentences.pop(i) - empty_sents.append((i, scripture_refs.pop(i))) sentence_translation_groups: List[SentenceTranslationGroup] = list( - self.translate(sentences, src_isos, trg_isos, produce_multiple_translations, vrefs) + self.translate(sentences, produce_multiple_translations) ) num_drafts = 
len(sentence_translation_groups[0]) # Add empty sentences back in # Prevents pre-existing text from showing up in the sections of translated text for idx, vref in reversed(empty_sents): - sentences.insert(idx, "") - scripture_refs.insert(idx, vref) + sentences.insert(idx, TranslationInputSentence("", "", "", vref, vref.verse_ref)) sentence_translation_groups.insert(idx, [SentenceTranslation("", [], [], None)] * num_drafts) text_behavior = ( @@ -429,13 +436,17 @@ def translate_usfm( draft_set: DraftGroup = DraftGroup(sentence_translation_groups) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): - postprocess_handler.construct_rows(scripture_refs, sentences, translated_draft.get_all_translations()) + postprocess_handler.construct_rows( + [s.scripture_ref for s in sentences if s.scripture_ref is not None], + [s.text for s in sentences], + translated_draft.get_all_translations(), + ) for config in postprocess_handler.configs: # Compile draft remarks draft_src_str = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}" - draft_remark = f"This draft of {scripture_refs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." + draft_remark = f"This draft of {sentences[0].scripture_ref.book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." 
postprocess_remark = config.get_postprocess_remark() remarks = [draft_remark] + ([postprocess_remark] if postprocess_remark else []) @@ -470,7 +481,7 @@ def translate_usfm( usfm = f.read() handler = UpdateUsfmParserHandler( rows=config.rows, - id_text=scripture_refs[0].book, + id_text=sentences[0].scripture_ref.book, text_behavior=text_behavior, paragraph_behavior=config.get_paragraph_behavior(), embed_behavior=config.get_embed_behavior(), @@ -517,7 +528,7 @@ def translate_usfm( translated_draft, trg_file_path, produce_multiple_translations=produce_multiple_translations, - scripture_refs=scripture_refs, + scripture_refs=[s.scripture_ref for s in sentences if s.scripture_ref is not None], draft_index=draft_index, ) @@ -541,21 +552,17 @@ def translate_docx( with src_file_path.open("rb") as file: doc = docx.Document(file) - sentences: List[str] = [] + translation_inputs: List[TranslationInputSentence] = [] paras: List[int] = [] - src_isos: List[str] = [] - trg_isos: List[str] = [] for i, paragraph in enumerate(doc.paragraphs): for sentence in tokenizer.tokenize(paragraph.text): - sentences.append(add_tags_to_sentence(tags, sentence)) - src_isos.append(src_iso) - trg_isos.append(trg_iso) + translation_inputs.append( + TranslationInputSentence(add_tags_to_sentence(tags, sentence), src_iso, trg_iso) + ) paras.append(i) - draft_set: DraftGroup = DraftGroup( - list(self.translate(sentences, src_isos, trg_isos, produce_multiple_translations)) - ) + draft_set: DraftGroup = DraftGroup(list(self.translate(translation_inputs, produce_multiple_translations))) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): for para, group in groupby(zip(translated_draft.get_all_translations(), paras), key=lambda t: t[1]): diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index 7a829f6a..12785415 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -32,7 +32,7 @@ write_corpus, ) from ..common.environment import SIL_NLP_ENV -from ..common.translator 
import SentenceTranslationGroup +from ..common.translator import SentenceTranslationGroup, TranslationInputSentence from ..common.utils import NoiseMethod, Side, add_tags_to_dataframe, add_tags_to_sentence, get_mt_exp_dir, set_seed from .augment import AugmentMethod from .corpora import ( @@ -71,6 +71,7 @@ def save_effective_config(self, path: Path) -> None: ... def translate_test_files( self, input_paths: List[Path], + test_gold_standard_paths: List[Path], translation_paths: List[Path], produce_multiple_translations: bool = False, save_confidences: bool = False, @@ -81,11 +82,8 @@ def translate_test_files( @abstractmethod def translate( self, - sentences: List[str], - src_isos: List[str], - trg_isos: List[str], + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[List[VerseRef]] = None, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> Generator[SentenceTranslationGroup, None, None]: ... diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index 60708ebb..58c681b2 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -53,8 +53,8 @@ T5Tokenizer, T5TokenizerFast, TensorType, + Text2TextGenerationPipeline, TrainerCallback, - TranslationPipeline, set_seed, ) from transformers.convert_slow_tokenizer import convert_slow_tokenizer @@ -79,6 +79,7 @@ DraftGroup, SentenceTranslation, SentenceTranslationGroup, + TranslationInputSentence, generate_test_confidence_files, ) from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict @@ -815,20 +816,16 @@ def batch_prepare_for_model( return BatchEncoding(batch_outputs, tensor_type=return_tensors) -TSent = TypeVar("TSent") - - def batch_sentences( - sentences: Iterable[TSent], - vrefs: Optional[Iterable[VerseRef]], + sentences: Iterable[TranslationInputSentence], batch_size: int, dictionary: Dict[VerseRef, Set[str]], -) -> 
Iterable[Tuple[List[TSent], Optional[List[List[List[str]]]]]]: - batch: List[TSent] = [] - for sentence, vref in zip(sentences, repeat(None) if vrefs is None else vrefs): +) -> Iterable[Tuple[List[TranslationInputSentence], Optional[List[List[List[str]]]]]]: + batch: List[TranslationInputSentence] = [] + for sentence in sentences: terms: Set[str] = set() - if vref is not None: - for vr in vref.all_verses(): + if sentence.vref is not None: + for vr in sentence.vref.all_verses(): terms.update(dictionary.get(vr, set())) if len(terms) > 0: if len(batch) > 0: @@ -1173,6 +1170,7 @@ def save_effective_config(self, path: Path) -> None: def translate_test_files( self, input_paths: List[Path], + test_gold_standard_paths: List[Path], translation_paths: List[Path], produce_multiple_translations: bool = False, save_confidences: bool = False, @@ -1184,27 +1182,38 @@ def translate_test_files( pipeline = PretokenizedTranslationPipeline( model=model, tokenizer=tokenizer, - src_lang=self._config.test_src_lang, - tgt_lang=self._config.test_trg_lang, device=0, ) pipeline.model = torch.compile(pipeline.model) - for input_path, translation_path, vref_path in zip( + for input_path, test_gold_standard_path, translation_path, vref_path in zip( input_paths, + test_gold_standard_paths, translation_paths, cast(Iterable[Optional[Path]], repeat(None) if vref_paths is None else vref_paths), ): - length = count_lines(input_path) with ExitStack() as stack: src_file = stack.enter_context(input_path.open("r", encoding="utf-8-sig")) - sentences = (line.strip().split() for line in src_file) + sentences = [line.strip().split() for line in src_file] + src_isos = [sentence[0] for sentence in sentences] + sentences = [" ".join(sentence[1:]) for sentence in sentences] + + gold_trg_file = stack.enter_context(test_gold_standard_path.open("r", encoding="utf-8-sig")) + trg_isos = [line.strip().split()[0] for line in gold_trg_file] + vrefs: Optional[Iterable[VerseRef]] = None if vref_path is not None: vref_file 
= stack.enter_context(vref_path.open("r", encoding="utf-8-sig")) vrefs = (VerseRef.from_string(line.strip(), ORIGINAL_VERSIFICATION) for line in vref_file) + + translation_inputs = [ + TranslationInputSentence(sentence, src_iso, trg_iso, None, vref) + for src_iso, trg_iso, sentence, vref in zip( + src_isos, trg_isos, sentences, vrefs if vrefs is not None else repeat(None) + ) + ] sentence_translation_groups: List[SentenceTranslationGroup] = list( self._translate_test_sentences( - tokenizer, pipeline, sentences, vrefs, length, produce_multiple_translations + tokenizer, pipeline, translation_inputs, produce_multiple_translations ) ) draft_group = DraftGroup(sentence_translation_groups) @@ -1230,10 +1239,8 @@ def translate_test_files( def _translate_test_sentences( self, tokenizer: PreTrainedTokenizer, - pipeline: TranslationPipeline, - sentences: Iterable[List[str]], - vrefs: Iterable[VerseRef], - length: int, + pipeline: "SilTranslationPipeline", + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, ) -> Iterable[SentenceTranslationGroup]: num_drafts = self.get_num_drafts() @@ -1247,9 +1254,9 @@ def _translate_test_sentences( for model_output_group in tqdm( self._translate_sentences( - tokenizer, pipeline, sentences, vrefs, produce_multiple_translations, return_tensors=True + tokenizer, pipeline, sentences, produce_multiple_translations, return_tensors=True ), - total=length, + total=len(sentences), unit="ex", ): yield model_output_group.convert_to_sentence_translation_group(tokenizer) @@ -1260,15 +1267,12 @@ def get_num_drafts(self) -> int: def translate( self, - sentences: List[str], - src_isos: List[str], - trg_isos: List[str], + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> Generator[SentenceTranslationGroup, None, None]: - src_langs =
[self._config.data["lang_codes"].get(src_iso, src_iso) for src_iso in src_isos] - trg_langs = [self._config.data["lang_codes"].get(trg_iso, trg_iso) for trg_iso in trg_isos] + src_langs = [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences] + trg_langs = [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences] inference_model_params = InferenceModelParams(ckpt, src_langs[0], trg_langs[0]) tokenizer = self._config.get_tokenizer() if self._inference_model_params == inference_model_params and self._cached_inference_model is not None: model = self._cached_inference_model else: @@ -1289,8 +1293,6 @@ def translate( pipeline = SilTranslationPipeline( model=model, tokenizer=tokenizer, - src_langs=src_langs, - tgt_langs=trg_langs, device=0, ) @@ -1307,7 +1309,7 @@ def translate( if not isinstance(sentences, list): sentences = list(sentences) for model_output_group in tqdm( - self._translate_sentences(tokenizer, pipeline, sentences, vrefs, produce_multiple_translations), + self._translate_sentences(tokenizer, pipeline, sentences, produce_multiple_translations), total=len(sentences), unit="ex", ): @@ -1512,16 +1514,15 @@ def _merge_and_delete_adapter(self, checkpoint_path: Path, vocab_size: int, save def _translate_sentences( self, tokenizer: PreTrainedTokenizer, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], - vrefs: Optional[Iterable[VerseRef]], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], produce_multiple_translations: bool = False, return_tensors: bool = False, ) -> Iterable[ModelOutputGroup]: batch_size: int = self._config.infer["infer_batch_size"] dictionary = self._get_dictionary() - if vrefs is None or len(dictionary) == 0: + if not any(s.vref is not None for s in sentences) or len(dictionary) == 0: yield from self._translate_sentence_helper( pipeline, sentences, @@ -1530,7 +1531,7 @@ def _translate_sentences( produce_multiple_translations=produce_multiple_translations, ) else: - for batch, force_words in
batch_sentences(sentences, vrefs, batch_size, dictionary): + for batch, force_words in batch_sentences(sentences, batch_size, dictionary): if force_words is None: force_words_ids = None else: @@ -1548,8 +1549,8 @@ def _translate_sentences( def _translate_sentence_helper( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, force_words_ids: List[List[List[int]]] = None, @@ -1560,8 +1561,6 @@ def _translate_sentence_helper( if produce_multiple_translations and num_drafts > 1: multiple_translations_method: str = self._config.infer.get("multiple_translations_method") - sentences = list(sentences) - if multiple_translations_method == "hybrid": beam_search_results: List[dict] = self._translate_with_beam_search( pipeline, @@ -1649,8 +1648,8 @@ def _flatten_tokenized_translations(self, pipeline_output) -> List[dict]: def _translate_with_beam_search( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, num_return_sequences: int = 1, @@ -1660,8 +1659,12 @@ def _translate_with_beam_search( if num_beams is None: num_beams = self._config.params.get("generation_num_beams") + pipeline.update_source_and_target_languages( + [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences], + [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], + ) translations = pipeline( - sentences, + [s.text for s in sentences], num_beams=num_beams, num_return_sequences=num_return_sequences, force_words_ids=force_words_ids, @@ -1677,8 +1680,8 @@ def _translate_with_beam_search( def _translate_with_sampling( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, 
return_tensors: bool, num_return_sequences: int = 1, @@ -1687,8 +1690,12 @@ def _translate_with_sampling( temperature: Optional[int] = self._config.infer.get("temperature") + pipeline.update_source_and_target_languages( + [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences], + [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], + ) translations = pipeline( - sentences, + [s.text for s in sentences], do_sample=True, temperature=temperature, num_return_sequences=num_return_sequences, @@ -1705,8 +1712,8 @@ def _translate_with_sampling( def _translate_with_diverse_beam_search( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, num_return_sequences: int = 1, @@ -1717,8 +1724,12 @@ def _translate_with_diverse_beam_search( num_beams = self._config.params.get("generation_num_beams") diversity_penalty: Optional[float] = self._config.infer.get("diversity_penalty") + pipeline.update_source_and_target_languages( + [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences], + [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], + ) translations = pipeline( - sentences, + [s.text for s in sentences], num_beams=num_beams, num_beam_groups=num_beams, num_return_sequences=num_return_sequences, @@ -1804,22 +1815,22 @@ def _configure_model( if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") - if ( + if ( src_lang != "" and trg_lang != "" and isinstance( tokenizer, (MBartTokenizer, MBartTokenizerFast, M2M100Tokenizer, NllbTokenizer, NllbTokenizerFast) ) - ): - tokenizer.src_lang = src_lang - tokenizer.tgt_lang = trg_lang + ): + tokenizer.src_lang = src_lang + tokenizer.tgt_lang = trg_lang - # For multilingual translation models like mBART-50 and M2M100 we need 
to force the target language token - # as the first generated token. - forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang) - model.config.forced_bos_token_id = forced_bos_token_id - if model.generation_config is not None: - model.generation_config.forced_bos_token_id = forced_bos_token_id + # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token + # as the first generated token. + forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang) + model.config.forced_bos_token_id = forced_bos_token_id + if model.generation_config is not None: + model.generation_config.forced_bos_token_id = forced_bos_token_id return model, tokenizer @@ -1857,6 +1868,9 @@ def token_to_id(self, token: str) -> int: def decode(self, *args, **kwargs): return self._wrapped_tokenizer.decode(*args, **kwargs) + def set_src_lang(self, src_lang: str): + self._wrapped_tokenizer.src_lang = src_lang + class HuggingFaceTokenizer(Tokenizer): def __init__( @@ -1941,22 +1955,48 @@ def normalize(self, line: NormalizedString) -> None: self._tokenizer.normalize_normalized_string(line) -class SilTranslationPipeline(TranslationPipeline): - def __init__( - self, - model: PreTrainedModel, - tokenizer: PreTrainedTokenizer, - src_langs: List[str], - tgt_langs: List[str], - device: int, - decoder_lang_code_tokens: List[int] | None = None, - ): - super().__init__(model=model, tokenizer=tokenizer, src_lang=src_langs[0], tgt_lang=tgt_langs[0], device=device) +class SilTranslationPipeline(Text2TextGenerationPipeline): + def update_source_and_target_languages(self, src_langs: List[str], tgt_langs: List[str]) -> None: + self.src_langs = src_langs + self.src_index = 0 self.tgt_langs = np.array( - [tokenizer.convert_tokens_to_ids(trg_lang) for trg_lang in tgt_langs], dtype=np.float32 + [self.tokenizer.convert_tokens_to_ids(trg_lang) for trg_lang in tgt_langs], dtype=np.int64 ) self.tgt_index = 0 + def preprocess(self, *args, 
truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None): + if getattr(self.tokenizer, "_build_translation_inputs", None): + return self.tokenizer._build_translation_inputs( + *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang + ) + else: + return self._parse_and_tokenize(*args, truncation=truncation) + + def _parse_and_tokenize(self, *args, truncation): + prefix = self.prefix if self.prefix is not None else "" + if isinstance(args[0], list): # TODO: disallow this case + if self.tokenizer.pad_token_id is None: + raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") + args = ([prefix + self.src_langs[self.src_index + i] + " " + arg for i, arg in enumerate(args[0])],) + self.tokenizer.src_lang = "" + self.src_index += len(args[0]) + padding = True + + elif isinstance(args[0], str): + args = (prefix + args[0],) + self.tokenizer.set_src_lang(self.src_langs[self.src_index]) + self.src_index += 1 + padding = False + else: + raise ValueError( + f" `args[0]`: {args[0]} have the wrong format. 
The should be either of type `str` or type `list`" + ) + inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework) + # This is produced by tokenizers but is an invalid generate kwargs + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + return inputs + def _forward(self, model_inputs, **generate_kwargs): in_b, input_length = model_inputs["input_ids"].shape @@ -1970,7 +2010,7 @@ def _forward(self, model_inputs, **generate_kwargs): generate_kwargs["decoder_input_ids"] = torch.cat( ( torch.ones((in_b, 1), dtype=torch.long, device=model_inputs["input_ids"].device) * 2, - torch.unsqueeze(torch.from_numpy(self.tgt_langs[self.tgt_index : self.tgt_index + in_b]), 0).to( + torch.unsqueeze(torch.from_numpy(self.tgt_langs[self.tgt_index : self.tgt_index + in_b]), 1).to( model_inputs["input_ids"].device ), ), diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py index 74a66037..05aa80fe 100644 --- a/silnlp/nmt/test.py +++ b/silnlp/nmt/test.py @@ -478,6 +478,7 @@ def test_checkpoint( vref_file_names: List[str] = [] source_file_names: List[str] = [] translation_file_names: List[str] = [] + gold_standard_detok_file_names: List[str] = [] refs_patterns: List[str] = [] translation_detok_file_names: List[str] = [] translation_conf_file_names: List[str] = [] @@ -495,6 +496,7 @@ def test_checkpoint( refs_patterns.append("test.trg.detok*.txt") translation_detok_file_names.append(f"test.trg-predictions.detok.txt.{suffix_str}") translation_conf_file_names.append(f"test.trg-predictions.txt.{suffix_str}.confidences.tsv") + gold_standard_detok_file_names.append("test.trg.detok.txt") else: # test data is split into separate files for src_iso in sorted(config.test_src_isos): @@ -510,14 +512,17 @@ def test_checkpoint( refs_patterns.append(f"{prefix}.trg.detok*.txt") translation_detok_file_names.append(f"{prefix}.trg-predictions.detok.txt.{suffix_str}") 
translation_conf_file_names.append(f"{prefix}.trg-predictions.txt.{suffix_str}.confidences.tsv") + gold_standard_detok_file_names.append(f"{prefix}.trg.detok.txt") checkpoint_name = "averaged checkpoint" if step == -1 else f"checkpoint {step}" source_paths: List[Path] = [] + gold_standard_target_paths: List[Path] = [] vref_paths: Optional[List[Path]] = [] if config.has_scripture_data else None translation_paths: List[Path] = [] for i in range(len(translation_file_names)): predictions_path = config.exp_dir / translation_file_names[i] + gold_standard_target_paths.append(config.exp_dir / gold_standard_detok_file_names[i]) if force_infer or not predictions_path.is_file(): source_paths.append(config.exp_dir / source_file_names[i]) translation_paths.append(predictions_path) @@ -527,6 +532,7 @@ def test_checkpoint( LOGGER.info(f"Inferencing {checkpoint_name}") model.translate_test_files( source_paths, + gold_standard_target_paths, translation_paths, produce_multiple_translations, save_confidences, diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index 7866397f..48eb41d2 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -2,17 +2,16 @@ import logging import os import time -from contextlib import AbstractContextManager from dataclasses import dataclass from pathlib import Path -from typing import Generator, Iterable, List, Optional, Tuple, Union +from typing import Generator, List, Optional, Tuple, Union -from machine.scripture import VerseRef, book_number_to_id, get_chapters +from machine.scripture import book_number_to_id, get_chapters from ..common.environment import SIL_NLP_ENV from ..common.paratext import book_file_name_digits, get_project_dir from ..common.postprocesser import PostprocessConfig, PostprocessHandler -from ..common.translator import SentenceTranslationGroup, Translator +from ..common.translator import SentenceTranslationGroup, TranslationInputSentence, Translator from ..common.utils import get_git_revision_hash, 
show_attrs from .clearml_connection import TAGS_LIST, SILClearML from .config import CheckpointType, Config, NMTModel @@ -27,15 +26,10 @@ def __init__(self, model: NMTModel, checkpoint: Union[CheckpointType, str, int]) def translate( self, - sentences: List[str], - src_isos: List[str], - trg_isos: List[str], + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: - yield from self._model.translate( - sentences, src_isos, trg_isos, produce_multiple_translations, vrefs, self._checkpoint - ) + yield from self._model.translate(sentences, produce_multiple_translations, self._checkpoint) def __exit__( self, exc_type, exc_value, traceback # pyright: ignore[reportUnknownParameterType, reportMissingParameterType] From bd14dd47dbf5c15360125cce16be3f3cee277612 Mon Sep 17 00:00:00 2001 From: Ben King Date: Thu, 29 Jan 2026 15:38:39 +0000 Subject: [PATCH 3/4] Multilingual inference for test step --- silnlp/common/translate_google.py | 2 +- silnlp/common/translator.py | 143 ++++++++++++++++++++++++++---- silnlp/nmt/config.py | 20 ++++- silnlp/nmt/hugging_face_config.py | 89 +++++++++---------- silnlp/nmt/test.py | 8 +- 5 files changed, 189 insertions(+), 73 deletions(-) diff --git a/silnlp/common/translate_google.py b/silnlp/common/translate_google.py index 09772044..bd477d90 100644 --- a/silnlp/common/translate_google.py +++ b/silnlp/common/translate_google.py @@ -25,7 +25,7 @@ def translate( LOGGER.warning("Google Translator does not support --multiple-translations") for sentence in sentences: - if len(sentence.text) == 0: + if sentence.text is None or len(sentence.text) == 0: yield [SentenceTranslation("", [], [], None)] else: results = self._translate_client.translate( diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 19993833..5fcb4ced 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ 
-43,11 +43,97 @@ @dataclass class TranslationInputSentence: - text: str - src_iso: str - trg_iso: str - scripture_ref: ScriptureRef | None = None - vref: VerseRef | None = None + def __init__( + self, + text: str | None = None, + tokens: List[str] | None = None, + src_iso: str = "", + trg_iso: str = "", + scripture_ref: ScriptureRef | None = None, + vref: VerseRef | None = None, + ): + self._text = text + self._tokens = tokens + self._src_iso = src_iso + self._trg_iso = trg_iso + self._scripture_ref = scripture_ref + self._vref = vref + + @property + def text(self) -> str | None: + return self._text + + @property + def tokens(self) -> List[str] | None: + return self._tokens + + @property + def src_iso(self) -> str: + return self._src_iso + + @property + def trg_iso(self) -> str: + return self._trg_iso + + @property + def scripture_ref(self) -> ScriptureRef | None: + return self._scripture_ref + + @property + def vref(self) -> VerseRef | None: + return self._vref + + def has_tokens(self) -> bool: + return self.tokens is not None + + class Builder: + def __init__(self): + self._text = None + self._tokens = None + self._src_iso = None + self._trg_iso = None + self._scripture_ref = None + self._vref = None + + def set_text(self, text: str) -> "TranslationInputSentence.Builder": + self._text = text + return self + + def set_tokens(self, tokens: List[str]) -> "TranslationInputSentence.Builder": + self._tokens = tokens + return self + + def set_src_iso(self, src_iso: str) -> "TranslationInputSentence.Builder": + self._src_iso = src_iso + return self + + def set_trg_iso(self, trg_iso: str) -> "TranslationInputSentence.Builder": + self._trg_iso = trg_iso + return self + + def set_scripture_ref(self, scripture_ref: ScriptureRef) -> "TranslationInputSentence.Builder": + self._scripture_ref = scripture_ref + return self + + def set_verse_ref(self, vref: VerseRef) -> "TranslationInputSentence.Builder": + self._vref = vref + return self + + def build(self) -> 
"TranslationInputSentence": + if self._text is None and self._tokens is None: + raise ValueError("TranslationInputSentence must have either text or tokens defined") + if self._src_iso is None: + raise ValueError("TranslationInputSentence must have a src_iso defined") + if self._trg_iso is None: + raise ValueError("TranslationInputSentence must have a trg_iso defined") + return TranslationInputSentence( + text=self._text, + tokens=self._tokens, + src_iso=self._src_iso, + trg_iso=self._trg_iso, + scripture_ref=self._scripture_ref, + vref=self._vref, + ) # A single translation of a single sentence @@ -293,7 +379,11 @@ def translate_text( ) -> None: translation_inputs = [ - TranslationInputSentence(add_tags_to_sentence(tags, sentence), src_iso, trg_iso) + TranslationInputSentence.Builder() + .set_text(add_tags_to_sentence(tags, sentence)) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .build() for sentence in load_corpus(src_file_path) ] sentence_translation_groups: List[SentenceTranslationGroup] = list( @@ -397,26 +487,37 @@ def translate_usfm( src_file_text = UsfmFileText(stylesheet, "utf-8-sig", book_id, src_file_path, include_all_text=True) sentences = [ - TranslationInputSentence( - re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())), src_iso, trg_iso, s.ref, s.ref.verse_ref - ) + TranslationInputSentence.Builder() + .set_text(re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip()))) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .set_scripture_ref(s.ref) + .set_verse_ref(s.ref.verse_ref) + .build() for s in src_file_text ] LOGGER.info(f"File {src_file_path} parsed correctly.") # Filter sentences for i in reversed(range(len(sentences))): - marker = sentences[i].scripture_ref.path[-1].name if len(sentences[i].scripture_ref.path) > 0 else "" + sentence_scripture_ref = sentences[i].scripture_ref + if sentence_scripture_ref is None: + continue + marker = sentence_scripture_ref.path[-1].name if len(sentence_scripture_ref.path) > 0 else "" if ( - 
(len(chapters) > 0 and sentences[i].scripture_ref.chapter_num not in chapters) + (len(chapters) > 0 and sentence_scripture_ref.chapter_num not in chapters) or marker in PARAGRAPH_TYPE_EMBEDS or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT ): sentences.pop(i) empty_sents: List[Tuple[int, ScriptureRef]] = [] for i in reversed(range(len(sentences))): - if len(sentences[i].text.strip()) == 0: - empty_sents.append((i, sentences[i].scripture_ref)) + sentence_scripture_ref = sentences[i].scripture_ref + if sentence_scripture_ref is None: + continue + sentence_text = sentences[i].text + if (sentence_text is None or len(sentence_text.strip()) == 0) and sentence_scripture_ref is not None: + empty_sents.append((i, sentence_scripture_ref)) sentences.pop(i) sentence_translation_groups: List[SentenceTranslationGroup] = list( @@ -427,7 +528,7 @@ def translate_usfm( # Add empty sentences back in # Prevents pre-existing text from showing up in the sections of translated text for idx, vref in reversed(empty_sents): - sentences.insert(idx, TranslationInputSentence("", "", "", vref, vref.verse_ref)) + sentences.insert(idx, TranslationInputSentence(None, None, "", "", vref, vref.verse_ref)) sentence_translation_groups.insert(idx, [SentenceTranslation("", [], [], None)] * num_drafts) text_behavior = ( @@ -438,15 +539,17 @@ def translate_usfm( for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): postprocess_handler.construct_rows( [s.scripture_ref for s in sentences if s.scripture_ref is not None], - [s.text for s in sentences], + [s.text or "" for s in sentences], translated_draft.get_all_translations(), ) for config in postprocess_handler.configs: + first_scripture_ref = sentences[0].scripture_ref + assert first_scripture_ref is not None # Compile draft remarks draft_src_str = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}" - draft_remark = f"This draft of {sentences[0].scripture_ref.book} was machine 
translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." + draft_remark = f"This draft of {first_scripture_ref.book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." postprocess_remark = config.get_postprocess_remark() remarks = [draft_remark] + ([postprocess_remark] if postprocess_remark else []) @@ -481,7 +584,7 @@ def translate_usfm( usfm = f.read() handler = UpdateUsfmParserHandler( rows=config.rows, - id_text=sentences[0].scripture_ref.book, + id_text=first_scripture_ref.book, text_behavior=text_behavior, paragraph_behavior=config.get_paragraph_behavior(), embed_behavior=config.get_embed_behavior(), @@ -558,7 +661,11 @@ def translate_docx( for i, paragraph in enumerate(doc.paragraphs): for sentence in tokenizer.tokenize(paragraph.text): translation_inputs.append( - TranslationInputSentence(add_tags_to_sentence(tags, sentence), src_iso, trg_iso) + TranslationInputSentence.Builder() + .set_text(add_tags_to_sentence(tags, sentence)) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .build() ) paras.append(i) diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index 12785415..5585ab8f 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -632,12 +632,16 @@ def _write_scripture_data_sets( project = column[len("target_") :] self._append_corpus( self.test_trg_filename(src_iso, trg_iso, project), + tokenizer.tokenize_all(Side.TARGET, pair_test[column]), + ) + self._append_corpus( + self.test_trg_detok_filename(src_iso, trg_iso, project), tokenizer.normalize_all(Side.TARGET, pair_test[column]), ) test_projects.remove(project) if self._has_multiple_test_projects(src_iso, trg_iso): for project in test_projects: - self._fill_corpus(self.test_trg_filename(src_iso, trg_iso, project), len(pair_test)) + self._fill_corpus(self.test_trg_detok_filename(src_iso, trg_iso, project), len(pair_test)) 
LOGGER.info(f"train size: {train_count}," f" val size: {val_count}," f" test size: {test_count},") return train_count @@ -1005,6 +1009,9 @@ def _write_basic_data_file_pair( val_trg_file = stack.enter_context(self._open_append(self.val_trg_filename())) test_src_file = stack.enter_context(self._open_append(self.test_src_filename(src_file.iso, trg_file.iso))) test_trg_file = stack.enter_context(self._open_append(self.test_trg_filename(src_file.iso, trg_file.iso))) + test_trg_detok_file = stack.enter_context( + self._open_append(self.test_trg_detok_filename(src_file.iso, trg_file.iso)) + ) train_vref_file: Optional[TextIO] = None val_vref_file: Optional[TextIO] = None @@ -1024,7 +1031,7 @@ def _write_basic_data_file_pair( if self._has_multiple_test_projects(src_file.iso, trg_file.iso): test_trg_project_files = [ stack.enter_context( - self._open_append(self.test_trg_filename(src_file.iso, trg_file.iso, project)) + self._open_append(self.test_trg_detok_filename(src_file.iso, trg_file.iso, project)) ) for project in test_projects if project != BASIC_DATA_PROJECT @@ -1051,7 +1058,8 @@ def _write_basic_data_file_pair( if pair.is_test and (test_indices is None or index in test_indices): test_src_file.write(tokenizer.tokenize(Side.SOURCE, src_sentence) + "\n") - test_trg_file.write(tokenizer.normalize(Side.TARGET, trg_sentence) + "\n") + test_trg_file.write(tokenizer.tokenize(Side.TARGET, trg_sentence) + "\n") + test_trg_detok_file.write(tokenizer.normalize(Side.TARGET, trg_sentence) + "\n") if test_vref_file is not None: test_vref_file.write("\n") for test_trg_project_file in test_trg_project_files: @@ -1221,6 +1229,12 @@ def test_vref_filename(self, src_iso: str, trg_iso: str) -> str: return f"test.{src_iso}.{trg_iso}.vref.txt" if self._multiple_test_iso_pairs else "test.vref.txt" def test_trg_filename(self, src_iso: str, trg_iso: str, project: str = BASIC_DATA_PROJECT) -> str: + prefix = f"test.{src_iso}.{trg_iso}" if self._multiple_test_iso_pairs else "test" + 
has_multiple_test_projects = self._iso_pairs[(src_iso, trg_iso)].has_multiple_test_projects + suffix = f".{project}" if has_multiple_test_projects else "" + return f"{prefix}.trg{suffix}.txt" + + def test_trg_detok_filename(self, src_iso: str, trg_iso: str, project: str = BASIC_DATA_PROJECT) -> str: prefix = f"test.{src_iso}.{trg_iso}" if self._multiple_test_iso_pairs else "test" has_multiple_test_projects = self._iso_pairs[(src_iso, trg_iso)].has_multiple_test_projects suffix = f".{project}" if has_multiple_test_projects else "" diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index 58c681b2..afc65f7f 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -11,7 +11,7 @@ from itertools import repeat from math import prod from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast import datasets.utils.logging as datasets_logging import evaluate @@ -50,6 +50,7 @@ PreTrainedTokenizerFast, Seq2SeqTrainer, Seq2SeqTrainingArguments, + SpecialTokensMixin, T5Tokenizer, T5TokenizerFast, TensorType, @@ -73,7 +74,7 @@ ) from transformers.utils.logging import tqdm -from ..common.corpus import Term, count_lines, get_terms +from ..common.corpus import Term, get_terms from ..common.environment import SIL_NLP_ENV from ..common.translator import ( DraftGroup, @@ -884,16 +885,10 @@ def convert_to_sentence_translation_group(self, tokenizer: PreTrainedTokenizer) @dataclass class InferenceModelParams: checkpoint: Union[CheckpointType, str, int] - src_lang: str - trg_lang: str def __post_init__(self): if not isinstance(self.checkpoint, (CheckpointType, str, int)): raise ValueError("checkpoint must be a CheckpointType, string, or integer") - if not isinstance(self.src_lang, str): - raise ValueError("src_lang must be a string") - if not 
isinstance(self.trg_lang, str): - raise ValueError("trg_lang must be a string") class HuggingFaceNMTModel(NMTModel): @@ -1178,7 +1173,7 @@ def translate_test_files( ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> None: tokenizer = self._config.get_tokenizer() - model = self._create_inference_model(ckpt, tokenizer, self._config.test_src_lang, self._config.test_trg_lang) + model = self._create_inference_model(ckpt, tokenizer) pipeline = PretokenizedTranslationPipeline( model=model, tokenizer=tokenizer, @@ -1195,10 +1190,12 @@ def translate_test_files( src_file = stack.enter_context(input_path.open("r", encoding="utf-8-sig")) sentences = [line.strip().split() for line in src_file] src_isos = [sentence[0] for sentence in sentences] - sentences = [" ".join(sentence[1:]) for sentence in sentences] - gold_trg_file = stack.enter_context(test_gold_standard_path.open("r", encoding="utf-8-sig")) - trg_isos = [line.strip().split()[0] for line in gold_trg_file] + if not test_gold_standard_path.exists(): + trg_isos = [self._config.test_trg_lang for _ in sentences] + else: + gold_trg_file = stack.enter_context(test_gold_standard_path.open("r", encoding="utf-8-sig")) + trg_isos = [line.strip().split()[0] for line in gold_trg_file] vrefs: Optional[Iterable[VerseRef]] = None if vref_path is not None: @@ -1206,7 +1203,12 @@ def translate_test_files( vrefs = (VerseRef.from_string(line.strip(), ORIGINAL_VERSIFICATION) for line in vref_file) translation_inputs = [ - TranslationInputSentence(src_iso, trg_iso, sentence, None, vref) + TranslationInputSentence.Builder() + .set_tokens(sentence) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .set_verse_ref(vref) + .build() for src_iso, trg_iso, sentence, vref in zip( src_isos, trg_isos, sentences, vrefs if vrefs is not None else repeat(None) ) @@ -1271,16 +1273,12 @@ def translate( produce_multiple_translations: bool = False, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> 
Generator[SentenceTranslationGroup, None, None]: - src_langs = [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences] - trg_langs = [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences] - inference_model_params = InferenceModelParams(ckpt, src_langs[0], trg_langs[0]) + inference_model_params = InferenceModelParams(ckpt) tokenizer = self._config.get_tokenizer() if self._inference_model_params == inference_model_params and self._cached_inference_model is not None: model = self._cached_inference_model else: - model = self._cached_inference_model = self._create_inference_model( - ckpt, tokenizer, src_langs[0], trg_langs[0] - ) + model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer) self._inference_model_params = inference_model_params if model.config.max_length is not None and model.config.max_length < 512: model.config.max_length = 512 @@ -1664,7 +1662,7 @@ def _translate_with_beam_search( [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], ) translations = pipeline( - [s.text for s in sentences], + [s.tokens if s.has_tokens() else s.text for s in sentences], num_beams=num_beams, num_return_sequences=num_return_sequences, force_words_ids=force_words_ids, @@ -1695,7 +1693,7 @@ def _translate_with_sampling( [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], ) translations = pipeline( - [s.text for s in sentences], + [s.tokens if s.has_tokens() else s.text for s in sentences], do_sample=True, temperature=temperature, num_return_sequences=num_return_sequences, @@ -1729,7 +1727,7 @@ def _translate_with_diverse_beam_search( [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], ) translations = pipeline( - [s.text for s in sentences], + [s.tokens if s.has_tokens() else s.text for s in sentences], num_beams=num_beams, num_beam_groups=num_beams, num_return_sequences=num_return_sequences, @@ -1749,8 +1747,8 @@ def 
_create_inference_model( self, ckpt: Union[CheckpointType, str, int], tokenizer: PreTrainedTokenizer, - src_lang: str, - trg_lang: str, + src_lang: str = "", + trg_lang: str = "", ) -> PreTrainedModel: if self._config.model_dir.exists(): checkpoint_path, _ = self.get_checkpoint_path(ckpt) @@ -1815,22 +1813,22 @@ def _configure_model( if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") - if ( + if ( src_lang != "" and trg_lang != "" and isinstance( tokenizer, (MBartTokenizer, MBartTokenizerFast, M2M100Tokenizer, NllbTokenizer, NllbTokenizerFast) ) - ): - tokenizer.src_lang = src_lang - tokenizer.tgt_lang = trg_lang + ): + tokenizer.src_lang = src_lang + tokenizer.tgt_lang = trg_lang - # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token - # as the first generated token. - forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang) - model.config.forced_bos_token_id = forced_bos_token_id - if model.generation_config is not None: - model.generation_config.forced_bos_token_id = forced_bos_token_id + # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token + # as the first generated token. 
+ forced_bos_token_id = tokenizer.convert_tokens_to_ids(trg_lang) + model.config.forced_bos_token_id = forced_bos_token_id + if model.generation_config is not None: + model.generation_config.forced_bos_token_id = forced_bos_token_id return model, tokenizer @@ -1871,6 +1869,10 @@ def decode(self, *args, **kwargs): def set_src_lang(self, src_lang: str): self._wrapped_tokenizer.src_lang = src_lang + @SpecialTokensMixin.pad_token_id.getter + def pad_token_id(self) -> int | None: + return self._wrapped_tokenizer.pad_token_id + class HuggingFaceTokenizer(Tokenizer): def __init__( @@ -1974,24 +1976,16 @@ def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_l def _parse_and_tokenize(self, *args, truncation): prefix = self.prefix if self.prefix is not None else "" - if isinstance(args[0], list): # TODO: disallow this case - if self.tokenizer.pad_token_id is None: - raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") - args = ([prefix + self.src_langs[self.src_index + i] + " " + arg for i, arg in enumerate(args[0])],) - self.tokenizer.src_lang = "" - self.src_index += len(args[0]) - padding = True + if isinstance(args[0], list): + raise ValueError("SilTranslationPipeline does not support batch tokenization") elif isinstance(args[0], str): args = (prefix + args[0],) self.tokenizer.set_src_lang(self.src_langs[self.src_index]) self.src_index += 1 - padding = False else: - raise ValueError( - f" `args[0]`: {args[0]} have the wrong format. 
The should be either of type `str` or type `list`" - ) - inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework) + raise ValueError("SilTranslationPipeline only supports string inputs for tokenization") + inputs = self.tokenizer(*args, padding=False, truncation=truncation, return_tensors=self.framework) # This is produced by tokenizers but is an invalid generate kwargs if "token_type_ids" in inputs: del inputs["token_type_ids"] @@ -2009,7 +2003,8 @@ def _forward(self, model_inputs, **generate_kwargs): self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) generate_kwargs["decoder_input_ids"] = torch.cat( ( - torch.ones((in_b, 1), dtype=torch.long, device=model_inputs["input_ids"].device) * 2, + torch.ones((in_b, 1), dtype=torch.long, device=model_inputs["input_ids"].device) + * self.tokenizer.pad_token_id, torch.unsqueeze(torch.from_numpy(self.tgt_langs[self.tgt_index : self.tgt_index + in_b]), 1).to( model_inputs["input_ids"].device ), diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py index 05aa80fe..f169ee81 100644 --- a/silnlp/nmt/test.py +++ b/silnlp/nmt/test.py @@ -478,7 +478,7 @@ def test_checkpoint( vref_file_names: List[str] = [] source_file_names: List[str] = [] translation_file_names: List[str] = [] - gold_standard_detok_file_names: List[str] = [] + gold_standard_target_file_names: List[str] = [] refs_patterns: List[str] = [] translation_detok_file_names: List[str] = [] translation_conf_file_names: List[str] = [] @@ -496,7 +496,7 @@ def test_checkpoint( refs_patterns.append("test.trg.detok*.txt") translation_detok_file_names.append(f"test.trg-predictions.detok.txt.{suffix_str}") translation_conf_file_names.append(f"test.trg-predictions.txt.{suffix_str}.confidences.tsv") - gold_standard_detok_file_names.append("test.trg.detok.txt") + gold_standard_target_file_names.append("test.trg.txt") else: # test data is split into separate files for src_iso in 
sorted(config.test_src_isos): @@ -512,7 +512,7 @@ def test_checkpoint( refs_patterns.append(f"{prefix}.trg.detok*.txt") translation_detok_file_names.append(f"{prefix}.trg-predictions.detok.txt.{suffix_str}") translation_conf_file_names.append(f"{prefix}.trg-predictions.txt.{suffix_str}.confidences.tsv") - gold_standard_detok_file_names.append(f"{prefix}.trg.detok.txt") + gold_standard_target_file_names.append(f"{prefix}.trg.txt") checkpoint_name = "averaged checkpoint" if step == -1 else f"checkpoint {step}" @@ -522,7 +522,7 @@ def test_checkpoint( translation_paths: List[Path] = [] for i in range(len(translation_file_names)): predictions_path = config.exp_dir / translation_file_names[i] - gold_standard_target_paths.append(config.exp_dir / gold_standard_detok_file_names[i]) + gold_standard_target_paths.append(config.exp_dir / gold_standard_target_file_names[i]) if force_infer or not predictions_path.is_file(): source_paths.append(config.exp_dir / source_file_names[i]) translation_paths.append(predictions_path) From 8b85778ac6df3a4bece3d7e0db8d1201e9eefe5b Mon Sep 17 00:00:00 2001 From: Ben King Date: Tue, 3 Feb 2026 19:01:31 +0000 Subject: [PATCH 4/4] Use correct start token for decoder --- silnlp/nmt/hugging_face_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index afc65f7f..dc332be8 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -1869,9 +1869,9 @@ def decode(self, *args, **kwargs): def set_src_lang(self, src_lang: str): self._wrapped_tokenizer.src_lang = src_lang - @SpecialTokensMixin.pad_token_id.getter - def pad_token_id(self) -> int | None: - return self._wrapped_tokenizer.pad_token_id + @SpecialTokensMixin.eos_token_id.getter + def eos_token_id(self) -> int | None: + return self._wrapped_tokenizer.eos_token_id class HuggingFaceTokenizer(Tokenizer): @@ -2004,7 +2004,7 @@ def _forward(self, model_inputs, 
**generate_kwargs): generate_kwargs["decoder_input_ids"] = torch.cat( ( torch.ones((in_b, 1), dtype=torch.long, device=model_inputs["input_ids"].device) - * self.tokenizer.pad_token_id, + * self.tokenizer.eos_token_id, torch.unsqueeze(torch.from_numpy(self.tgt_langs[self.tgt_index : self.tgt_index + in_b]), 1).to( model_inputs["input_ids"].device ),