diff --git a/silnlp/common/translate_google.py b/silnlp/common/translate_google.py index 9cf3bd97..bd477d90 100644 --- a/silnlp/common/translate_google.py +++ b/silnlp/common/translate_google.py @@ -6,7 +6,7 @@ from machine.scripture import VerseRef, book_id_to_number from .paratext import book_file_name_digits -from .translator import SentenceTranslation, SentenceTranslationGroup, Translator +from .translator import SentenceTranslation, SentenceTranslationGroup, TranslationInputSentence, Translator from .utils import get_git_revision_hash, get_mt_exp_dir LOGGER = logging.getLogger((__package__ or "") + ".translate") @@ -18,21 +18,18 @@ def __init__(self) -> None: def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: Iterable[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: if produce_multiple_translations: LOGGER.warning("Google Translator does not support --multiple-translations") for sentence in sentences: - if len(sentence) == 0: - yield "" + if sentence.text is None or len(sentence.text) == 0: + yield [SentenceTranslation("", [], [], None)] else: results = self._translate_client.translate( - sentence, source_language=src_iso, target_language=trg_iso, format_="text" + sentence.text, source_language=sentence.src_iso, target_language=sentence.trg_iso, format_="text" ) translation: str = results["translatedText"] yield [SentenceTranslation(translation, [], [], None)] diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 51c32879..5fcb4ced 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -3,11 +3,12 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import AbstractContextManager +from dataclasses import dataclass from datetime import date from itertools import groupby from math import exp from pathlib import Path 
-from typing import DefaultDict, Generator, Iterable, List, Optional, Tuple, cast +from typing import DefaultDict, Generator, List, Optional, Tuple import docx import nltk @@ -40,6 +41,101 @@ CONFIDENCE_SCORES_SUFFIX = ".confidences.tsv" +# Plain class by design: @dataclass with no field annotations would generate a no-op __eq__/__repr__ (all instances equal). +class TranslationInputSentence: + def __init__( + self, + text: str | None = None, + tokens: List[str] | None = None, + src_iso: str = "", + trg_iso: str = "", + scripture_ref: ScriptureRef | None = None, + vref: VerseRef | None = None, + ): + self._text = text + self._tokens = tokens + self._src_iso = src_iso + self._trg_iso = trg_iso + self._scripture_ref = scripture_ref + self._vref = vref + + @property + def text(self) -> str | None: + return self._text + + @property + def tokens(self) -> List[str] | None: + return self._tokens + + @property + def src_iso(self) -> str: + return self._src_iso + + @property + def trg_iso(self) -> str: + return self._trg_iso + + @property + def scripture_ref(self) -> ScriptureRef | None: + return self._scripture_ref + + @property + def vref(self) -> VerseRef | None: + return self._vref + + def has_tokens(self) -> bool: + return self.tokens is not None + + class Builder: + def __init__(self): + self._text = None + self._tokens = None + self._src_iso = None + self._trg_iso = None + self._scripture_ref = None + self._vref = None + + def set_text(self, text: str) -> "TranslationInputSentence.Builder": + self._text = text + return self + + def set_tokens(self, tokens: List[str]) -> "TranslationInputSentence.Builder": + self._tokens = tokens + return self + + def set_src_iso(self, src_iso: str) -> "TranslationInputSentence.Builder": + self._src_iso = src_iso + return self + + def set_trg_iso(self, trg_iso: str) -> "TranslationInputSentence.Builder": + self._trg_iso = trg_iso + return self + + def set_scripture_ref(self, scripture_ref: ScriptureRef) -> "TranslationInputSentence.Builder": + self._scripture_ref = scripture_ref + return self + + def set_verse_ref(self, vref: VerseRef) -> 
"TranslationInputSentence.Builder": + self._vref = vref + return self + + def build(self) -> "TranslationInputSentence": + if self._text is None and self._tokens is None: + raise ValueError("TranslationInputSentence must have either text or tokens defined") + if self._src_iso is None: + raise ValueError("TranslationInputSentence must have a src_iso defined") + if self._trg_iso is None: + raise ValueError("TranslationInputSentence must have a trg_iso defined") + return TranslationInputSentence( + text=self._text, + tokens=self._tokens, + src_iso=self._src_iso, + trg_iso=self._trg_iso, + scripture_ref=self._scripture_ref, + vref=self._vref, + ) + + # A single translation of a single sentence class SentenceTranslation: def __init__( @@ -265,11 +361,8 @@ class Translator(AbstractContextManager["Translator"], ABC): @abstractmethod def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: pass @@ -285,9 +378,16 @@ def translate_text( tags: Optional[List[str]] = None, ) -> None: - sentences = [add_tags_to_sentence(tags, sentence) for sentence in load_corpus(src_file_path)] + translation_inputs = [ + TranslationInputSentence.Builder() + .set_text(add_tags_to_sentence(tags, sentence)) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .build() + for sentence in load_corpus(src_file_path) + ] sentence_translation_groups: List[SentenceTranslationGroup] = list( - self.translate(sentences, src_iso, trg_iso, produce_multiple_translations) + self.translate(translation_inputs, produce_multiple_translations) ) draft_set = DraftGroup(sentence_translation_groups) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): @@ -386,37 +486,49 @@ def translate_usfm( src_file_text = UsfmFileText(stylesheet, "utf-8-sig", book_id, src_file_path, 
include_all_text=True) - sentences = [re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip())) for s in src_file_text] - scripture_refs: List[ScriptureRef] = [s.ref for s in src_file_text] - vrefs: List[VerseRef] = [sr.verse_ref for sr in scripture_refs] + sentences = [ + TranslationInputSentence.Builder() + .set_text(re.sub(" +", " ", add_tags_to_sentence(tags, s.text.strip()))) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .set_scripture_ref(s.ref) + .set_verse_ref(s.ref.verse_ref) + .build() + for s in src_file_text + ] LOGGER.info(f"File {src_file_path} parsed correctly.") # Filter sentences for i in reversed(range(len(sentences))): - marker = scripture_refs[i].path[-1].name if len(scripture_refs[i].path) > 0 else "" + sentence_scripture_ref = sentences[i].scripture_ref + if sentence_scripture_ref is None: + continue + marker = sentence_scripture_ref.path[-1].name if len(sentence_scripture_ref.path) > 0 else "" if ( - (len(chapters) > 0 and scripture_refs[i].chapter_num not in chapters) + (len(chapters) > 0 and sentence_scripture_ref.chapter_num not in chapters) or marker in PARAGRAPH_TYPE_EMBEDS or stylesheet.get_tag(marker).text_type == UsfmTextType.NOTE_TEXT ): sentences.pop(i) - scripture_refs.pop(i) empty_sents: List[Tuple[int, ScriptureRef]] = [] for i in reversed(range(len(sentences))): - if len(sentences[i].strip()) == 0: + sentence_scripture_ref = sentences[i].scripture_ref + if sentence_scripture_ref is None: + continue + sentence_text = sentences[i].text + if (sentence_text is None or len(sentence_text.strip()) == 0) and sentence_scripture_ref is not None: + empty_sents.append((i, sentence_scripture_ref)) sentences.pop(i) - empty_sents.append((i, scripture_refs.pop(i))) sentence_translation_groups: List[SentenceTranslationGroup] = list( - self.translate(sentences, src_iso, trg_iso, produce_multiple_translations, vrefs) + self.translate(sentences, produce_multiple_translations) ) num_drafts = len(sentence_translation_groups[0]) # Add empty 
sentences back in # Prevents pre-existing text from showing up in the sections of translated text for idx, vref in reversed(empty_sents): - sentences.insert(idx, "") - scripture_refs.insert(idx, vref) + sentences.insert(idx, TranslationInputSentence(None, None, "", "", vref, vref.verse_ref)) sentence_translation_groups.insert(idx, [SentenceTranslation("", [], [], None)] * num_drafts) text_behavior = ( @@ -425,13 +537,19 @@ def translate_usfm( draft_set: DraftGroup = DraftGroup(sentence_translation_groups) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): - postprocess_handler.construct_rows(scripture_refs, sentences, translated_draft.get_all_translations()) + postprocess_handler.construct_rows( + [s.scripture_ref for s in sentences if s.scripture_ref is not None], + [s.text or "" for s in sentences], + translated_draft.get_all_translations(), + ) for config in postprocess_handler.configs: + first_scripture_ref = sentences[0].scripture_ref + assert first_scripture_ref is not None # Compile draft remarks draft_src_str = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}" - draft_remark = f"This draft of {scripture_refs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." + draft_remark = f"This draft of {first_scripture_ref.book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully." 
postprocess_remark = config.get_postprocess_remark() remarks = [draft_remark] + ([postprocess_remark] if postprocess_remark else []) @@ -466,7 +584,7 @@ def translate_usfm( usfm = f.read() handler = UpdateUsfmParserHandler( rows=config.rows, - id_text=scripture_refs[0].book, + id_text=first_scripture_ref.book, text_behavior=text_behavior, paragraph_behavior=config.get_paragraph_behavior(), embed_behavior=config.get_embed_behavior(), @@ -513,7 +631,7 @@ def translate_usfm( translated_draft, trg_file_path, produce_multiple_translations=produce_multiple_translations, - scripture_refs=scripture_refs, + scripture_refs=[s.scripture_ref for s in sentences if s.scripture_ref is not None], draft_index=draft_index, ) @@ -537,17 +655,21 @@ def translate_docx( with src_file_path.open("rb") as file: doc = docx.Document(file) - sentences: List[str] = [] + translation_inputs: List[TranslationInputSentence] = [] paras: List[int] = [] for i, paragraph in enumerate(doc.paragraphs): for sentence in tokenizer.tokenize(paragraph.text): - sentences.append(add_tags_to_sentence(tags, sentence)) + translation_inputs.append( + TranslationInputSentence.Builder() + .set_text(add_tags_to_sentence(tags, sentence)) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .build() + ) paras.append(i) - draft_set: DraftGroup = DraftGroup( - list(self.translate(sentences, src_iso, trg_iso, produce_multiple_translations)) - ) + draft_set: DraftGroup = DraftGroup(list(self.translate(translation_inputs, produce_multiple_translations))) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): for para, group in groupby(zip(translated_draft.get_all_translations(), paras), key=lambda t: t[1]): diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index d2b24538..5585ab8f 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -32,7 +32,7 @@ write_corpus, ) from ..common.environment import SIL_NLP_ENV -from ..common.translator import SentenceTranslationGroup +from ..common.translator 
import SentenceTranslationGroup, TranslationInputSentence from ..common.utils import NoiseMethod, Side, add_tags_to_dataframe, add_tags_to_sentence, get_mt_exp_dir, set_seed from .augment import AugmentMethod from .corpora import ( @@ -71,6 +71,7 @@ def save_effective_config(self, path: Path) -> None: ... def translate_test_files( self, input_paths: List[Path], + test_gold_standard_paths: List[Path], translation_paths: List[Path], produce_multiple_translations: bool = False, save_confidences: bool = False, @@ -81,11 +82,8 @@ def translate_test_files( @abstractmethod def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> Generator[SentenceTranslationGroup, None, None]: ... @@ -634,12 +632,16 @@ def _write_scripture_data_sets( project = column[len("target_") :] self._append_corpus( self.test_trg_filename(src_iso, trg_iso, project), + tokenizer.tokenize_all(Side.TARGET, pair_test[column]), + ) + self._append_corpus( + self.test_trg_detok_filename(src_iso, trg_iso, project), tokenizer.normalize_all(Side.TARGET, pair_test[column]), ) test_projects.remove(project) if self._has_multiple_test_projects(src_iso, trg_iso): for project in test_projects: - self._fill_corpus(self.test_trg_filename(src_iso, trg_iso, project), len(pair_test)) + self._fill_corpus(self.test_trg_detok_filename(src_iso, trg_iso, project), len(pair_test)) LOGGER.info(f"train size: {train_count}," f" val size: {val_count}," f" test size: {test_count},") return train_count @@ -1007,6 +1009,9 @@ def _write_basic_data_file_pair( val_trg_file = stack.enter_context(self._open_append(self.val_trg_filename())) test_src_file = stack.enter_context(self._open_append(self.test_src_filename(src_file.iso, trg_file.iso))) test_trg_file = 
stack.enter_context(self._open_append(self.test_trg_filename(src_file.iso, trg_file.iso))) + test_trg_detok_file = stack.enter_context( + self._open_append(self.test_trg_detok_filename(src_file.iso, trg_file.iso)) + ) train_vref_file: Optional[TextIO] = None val_vref_file: Optional[TextIO] = None @@ -1026,7 +1031,7 @@ def _write_basic_data_file_pair( if self._has_multiple_test_projects(src_file.iso, trg_file.iso): test_trg_project_files = [ stack.enter_context( - self._open_append(self.test_trg_filename(src_file.iso, trg_file.iso, project)) + self._open_append(self.test_trg_detok_filename(src_file.iso, trg_file.iso, project)) ) for project in test_projects if project != BASIC_DATA_PROJECT @@ -1053,7 +1058,8 @@ def _write_basic_data_file_pair( if pair.is_test and (test_indices is None or index in test_indices): test_src_file.write(tokenizer.tokenize(Side.SOURCE, src_sentence) + "\n") - test_trg_file.write(tokenizer.normalize(Side.TARGET, trg_sentence) + "\n") + test_trg_file.write(tokenizer.tokenize(Side.TARGET, trg_sentence) + "\n") + test_trg_detok_file.write(tokenizer.normalize(Side.TARGET, trg_sentence) + "\n") if test_vref_file is not None: test_vref_file.write("\n") for test_trg_project_file in test_trg_project_files: @@ -1223,6 +1229,12 @@ def test_vref_filename(self, src_iso: str, trg_iso: str) -> str: return f"test.{src_iso}.{trg_iso}.vref.txt" if self._multiple_test_iso_pairs else "test.vref.txt" def test_trg_filename(self, src_iso: str, trg_iso: str, project: str = BASIC_DATA_PROJECT) -> str: + prefix = f"test.{src_iso}.{trg_iso}" if self._multiple_test_iso_pairs else "test" + has_multiple_test_projects = self._iso_pairs[(src_iso, trg_iso)].has_multiple_test_projects + suffix = f".{project}" if has_multiple_test_projects else "" + return f"{prefix}.trg{suffix}.txt" + + def test_trg_detok_filename(self, src_iso: str, trg_iso: str, project: str = BASIC_DATA_PROJECT) -> str: prefix = f"test.{src_iso}.{trg_iso}" if self._multiple_test_iso_pairs else "test" 
has_multiple_test_projects = self._iso_pairs[(src_iso, trg_iso)].has_multiple_test_projects suffix = f".{project}" if has_multiple_test_projects else "" diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index 0ab5af1a..dc332be8 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -11,7 +11,7 @@ from itertools import repeat from math import prod from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast import datasets.utils.logging as datasets_logging import evaluate @@ -50,11 +50,12 @@ PreTrainedTokenizerFast, Seq2SeqTrainer, Seq2SeqTrainingArguments, + SpecialTokensMixin, T5Tokenizer, T5TokenizerFast, TensorType, + Text2TextGenerationPipeline, TrainerCallback, - TranslationPipeline, set_seed, ) from transformers.convert_slow_tokenizer import convert_slow_tokenizer @@ -73,12 +74,13 @@ ) from transformers.utils.logging import tqdm -from ..common.corpus import Term, count_lines, get_terms +from ..common.corpus import Term, get_terms from ..common.environment import SIL_NLP_ENV from ..common.translator import ( DraftGroup, SentenceTranslation, SentenceTranslationGroup, + TranslationInputSentence, generate_test_confidence_files, ) from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict @@ -815,20 +817,16 @@ def batch_prepare_for_model( return BatchEncoding(batch_outputs, tensor_type=return_tensors) -TSent = TypeVar("TSent") - - def batch_sentences( - sentences: Iterable[TSent], - vrefs: Optional[Iterable[VerseRef]], + sentences: Iterable[TranslationInputSentence], batch_size: int, dictionary: Dict[VerseRef, Set[str]], -) -> Iterable[Tuple[List[TSent], Optional[List[List[List[str]]]]]]: - batch: List[TSent] = [] - for sentence, vref in zip(sentences, repeat(None) if 
vrefs is None else vrefs): +) -> Iterable[Tuple[List[TranslationInputSentence], Optional[List[List[List[str]]]]]]: + batch: List[TranslationInputSentence] = [] + for sentence in sentences: terms: Set[str] = set() - if vref is not None: - for vr in vref.all_verses(): + if sentence.vref is not None: + for vr in sentence.vref.all_verses(): terms.update(dictionary.get(vr, set())) if len(terms) > 0: if len(batch) > 0: @@ -887,16 +885,10 @@ def convert_to_sentence_translation_group(self, tokenizer: PreTrainedTokenizer) @dataclass class InferenceModelParams: checkpoint: Union[CheckpointType, str, int] - src_lang: str - trg_lang: str def __post_init__(self): if not isinstance(self.checkpoint, (CheckpointType, str, int)): raise ValueError("checkpoint must be a CheckpointType, string, or integer") - if not isinstance(self.src_lang, str): - raise ValueError("src_lang must be a string") - if not isinstance(self.trg_lang, str): - raise ValueError("trg_lang must be a string") class HuggingFaceNMTModel(NMTModel): @@ -1173,6 +1165,7 @@ def save_effective_config(self, path: Path) -> None: def translate_test_files( self, input_paths: List[Path], + test_gold_standard_paths: List[Path], translation_paths: List[Path], produce_multiple_translations: bool = False, save_confidences: bool = False, @@ -1180,31 +1173,49 @@ def translate_test_files( ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> None: tokenizer = self._config.get_tokenizer() - model = self._create_inference_model(ckpt, tokenizer, self._config.test_src_lang, self._config.test_trg_lang) + model = self._create_inference_model(ckpt, tokenizer) pipeline = PretokenizedTranslationPipeline( model=model, tokenizer=tokenizer, - src_lang=self._config.test_src_lang, - tgt_lang=self._config.test_trg_lang, device=0, ) pipeline.model = torch.compile(pipeline.model) - for input_path, translation_path, vref_path in zip( + for input_path, test_gold_standard_path, translation_path, vref_path in zip( input_paths, + 
test_gold_standard_paths, translation_paths, cast(Iterable[Optional[Path]], repeat(None) if vref_paths is None else vref_paths), ): - length = count_lines(input_path) with ExitStack() as stack: src_file = stack.enter_context(input_path.open("r", encoding="utf-8-sig")) - sentences = (line.strip().split() for line in src_file) + sentences = [line.strip().split() for line in src_file] + src_isos = [sentence[0] for sentence in sentences] + + if not test_gold_standard_path.exists(): + trg_isos = [self._config.test_trg_lang for _ in sentences] + else: + gold_trg_file = stack.enter_context(test_gold_standard_path.open("r", encoding="utf-8-sig")) + trg_isos = [line.strip().split()[0] for line in gold_trg_file] + vrefs: Optional[Iterable[VerseRef]] = None if vref_path is not None: vref_file = stack.enter_context(vref_path.open("r", encoding="utf-8-sig")) vrefs = (VerseRef.from_string(line.strip(), ORIGINAL_VERSIFICATION) for line in vref_file) + + translation_inputs = [ + TranslationInputSentence.Builder() + .set_tokens(sentence) + .set_src_iso(src_iso) + .set_trg_iso(trg_iso) + .set_verse_ref(vref) + .build() + for src_iso, trg_iso, sentence, vref in zip( + src_isos, trg_isos, sentences, vrefs if vrefs is not None else repeat(None) + ) + ] sentence_translation_groups: List[SentenceTranslationGroup] = list( self._translate_test_sentences( - tokenizer, pipeline, sentences, vrefs, length, produce_multiple_translations + tokenizer, pipeline, translation_inputs, produce_multiple_translations ) ) draft_group = DraftGroup(sentence_translation_groups) @@ -1230,10 +1241,8 @@ def translate_test_files( def _translate_test_sentences( self, tokenizer: PreTrainedTokenizer, - pipeline: TranslationPipeline, - sentences: Iterable[List[str]], - vrefs: Iterable[VerseRef], - length: int, + pipeline: "SilTranslationPipeline", + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, ) -> Iterable[SentenceTranslationGroup]: num_drafts = self.get_num_drafts() @@ 
-1247,9 +1256,9 @@ def _translate_test_sentences( for model_output_group in tqdm( self._translate_sentences( - tokenizer, pipeline, sentences, vrefs, produce_multiple_translations, return_tensors=True + tokenizer, pipeline, sentences, produce_multiple_translations, return_tensors=True ), - total=length, + total=len(sentences), unit="ex", ): yield model_output_group.convert_to_sentence_translation_group(tokenizer) @@ -1260,21 +1269,16 @@ def get_num_drafts(self) -> int: def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, ) -> Generator[SentenceTranslationGroup, None, None]: - src_lang = self._config.data["lang_codes"].get(src_iso, src_iso) - trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso) - inference_model_params = InferenceModelParams(ckpt, src_lang, trg_lang) + inference_model_params = InferenceModelParams(ckpt) tokenizer = self._config.get_tokenizer() if self._inference_model_params == inference_model_params and self._cached_inference_model is not None: model = self._cached_inference_model else: - model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang) + model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer) self._inference_model_params = inference_model_params if model.config.max_length is not None and model.config.max_length < 512: model.config.max_length = 512 @@ -1287,8 +1291,6 @@ def translate( pipeline = SilTranslationPipeline( model=model, tokenizer=tokenizer, - src_lang=src_lang, - tgt_lang=trg_lang, device=0, ) @@ -1305,7 +1307,7 @@ def translate( if not isinstance(sentences, list): sentences = list(sentences) for model_output_group in tqdm( - self._translate_sentences(tokenizer, pipeline, sentences, vrefs, produce_multiple_translations), 
+            self._translate_sentences(tokenizer, pipeline, sentences, produce_multiple_translations), total=len(sentences), unit="ex", ): @@ -1510,16 +1512,15 @@ def _merge_and_delete_adapter(self, checkpoint_path: Path, vocab_size: int, save def _translate_sentences( self, tokenizer: PreTrainedTokenizer, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], - vrefs: Optional[Iterable[VerseRef]], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], produce_multiple_translations: bool = False, return_tensors: bool = False, ) -> Iterable[ModelOutputGroup]: batch_size: int = self._config.infer["infer_batch_size"] dictionary = self._get_dictionary() - if vrefs is None or len(dictionary) == 0: + if not any(s.vref is not None for s in sentences) or len(dictionary) == 0: yield from self._translate_sentence_helper( pipeline, sentences, @@ -1528,7 +1529,7 @@ def _translate_sentences( produce_multiple_translations=produce_multiple_translations, ) else: - for batch, force_words in batch_sentences(sentences, vrefs, batch_size, dictionary): + for batch, force_words in batch_sentences(sentences, batch_size, dictionary): if force_words is None: force_words_ids = None else: @@ -1546,8 +1547,8 @@ def _translate_sentences( def _translate_sentence_helper( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, force_words_ids: List[List[List[int]]] = None, @@ -1558,8 +1559,6 @@ def _translate_sentence_helper( if produce_multiple_translations and num_drafts > 1: multiple_translations_method: str = self._config.infer.get("multiple_translations_method") - sentences = list(sentences) - if multiple_translations_method == "hybrid": beam_search_results: List[dict] = self._translate_with_beam_search( pipeline, @@ -1647,8 +1646,8 @@ def _flatten_tokenized_translations(self, pipeline_output) -> List[dict]: def 
_translate_with_beam_search( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, num_return_sequences: int = 1, @@ -1658,8 +1657,12 @@ def _translate_with_beam_search( if num_beams is None: num_beams = self._config.params.get("generation_num_beams") + pipeline.update_source_and_target_languages( + [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences], + [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], + ) translations = pipeline( - sentences, + [s.tokens if s.has_tokens() else s.text for s in sentences], num_beams=num_beams, num_return_sequences=num_return_sequences, force_words_ids=force_words_ids, @@ -1675,8 +1678,8 @@ def _translate_with_beam_search( def _translate_with_sampling( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, num_return_sequences: int = 1, @@ -1685,8 +1688,12 @@ def _translate_with_sampling( temperature: Optional[int] = self._config.infer.get("temperature") + pipeline.update_source_and_target_languages( + [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences], + [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], + ) translations = pipeline( - sentences, + [s.tokens if s.has_tokens() else s.text for s in sentences], do_sample=True, temperature=temperature, num_return_sequences=num_return_sequences, @@ -1703,8 +1710,8 @@ def _translate_with_sampling( def _translate_with_diverse_beam_search( self, - pipeline: TranslationPipeline, - sentences: Iterable[TSent], + pipeline: "SilTranslationPipeline", + sentences: Iterable[TranslationInputSentence], batch_size: int, return_tensors: bool, num_return_sequences: int = 1, @@ -1715,8 +1722,12 @@ def 
_translate_with_diverse_beam_search( num_beams = self._config.params.get("generation_num_beams") diversity_penalty: Optional[float] = self._config.infer.get("diversity_penalty") + pipeline.update_source_and_target_languages( + [self._config.data["lang_codes"].get(s.src_iso, s.src_iso) for s in sentences], + [self._config.data["lang_codes"].get(s.trg_iso, s.trg_iso) for s in sentences], + ) translations = pipeline( - sentences, + [s.tokens if s.has_tokens() else s.text for s in sentences], num_beams=num_beams, num_beam_groups=num_beams, num_return_sequences=num_return_sequences, @@ -1736,8 +1747,8 @@ def _create_inference_model( self, ckpt: Union[CheckpointType, str, int], tokenizer: PreTrainedTokenizer, - src_lang: str, - trg_lang: str, + src_lang: str = "", + trg_lang: str = "", ) -> PreTrainedModel: if self._config.model_dir.exists(): checkpoint_path, _ = self.get_checkpoint_path(ckpt) @@ -1855,6 +1866,13 @@ def token_to_id(self, token: str) -> int: def decode(self, *args, **kwargs): return self._wrapped_tokenizer.decode(*args, **kwargs) + def set_src_lang(self, src_lang: str): + self._wrapped_tokenizer.src_lang = src_lang + + @SpecialTokensMixin.eos_token_id.getter + def eos_token_id(self) -> int | None: + return self._wrapped_tokenizer.eos_token_id + class HuggingFaceTokenizer(Tokenizer): def __init__( @@ -1939,7 +1957,40 @@ def normalize(self, line: NormalizedString) -> None: self._tokenizer.normalize_normalized_string(line) -class SilTranslationPipeline(TranslationPipeline): +class SilTranslationPipeline(Text2TextGenerationPipeline): + def update_source_and_target_languages(self, src_langs: List[str], tgt_langs: List[str]) -> None: + self.src_langs = src_langs + self.src_index = 0 + self.tgt_langs = np.array( + [self.tokenizer.convert_tokens_to_ids(trg_lang) for trg_lang in tgt_langs], dtype=np.int64 + ) + self.tgt_index = 0 + + def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None): + if 
getattr(self.tokenizer, "_build_translation_inputs", None): + return self.tokenizer._build_translation_inputs( + *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang + ) + else: + return self._parse_and_tokenize(*args, truncation=truncation) + + def _parse_and_tokenize(self, *args, truncation): + prefix = self.prefix if self.prefix is not None else "" + if isinstance(args[0], list): + raise ValueError("SilTranslationPipeline does not support batch tokenization") + + elif isinstance(args[0], str): + args = (prefix + args[0],) + self.tokenizer.set_src_lang(self.src_langs[self.src_index]) + self.src_index += 1 + else: + raise ValueError("SilTranslationPipeline only supports string inputs for tokenization") + inputs = self.tokenizer(*args, padding=False, truncation=truncation, return_tensors=self.framework) + # This is produced by tokenizers but is an invalid generate kwargs + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + return inputs + def _forward(self, model_inputs, **generate_kwargs): in_b, input_length = model_inputs["input_ids"].shape @@ -1950,6 +2001,17 @@ def _forward(self, model_inputs, **generate_kwargs): generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length) generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) + generate_kwargs["decoder_input_ids"] = torch.cat( + ( + torch.ones((in_b, 1), dtype=torch.long, device=model_inputs["input_ids"].device) + * self.tokenizer.eos_token_id, + torch.unsqueeze(torch.from_numpy(self.tgt_langs[self.tgt_index : self.tgt_index + in_b]), 1).to( + model_inputs["input_ids"].device + ), + ), + dim=1, + ) + self.tgt_index += in_b output = self.model.generate( **model_inputs, **generate_kwargs, diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py index 74a66037..f169ee81 100644 --- a/silnlp/nmt/test.py +++ 
b/silnlp/nmt/test.py @@ -478,6 +478,7 @@ def test_checkpoint( vref_file_names: List[str] = [] source_file_names: List[str] = [] translation_file_names: List[str] = [] + gold_standard_target_file_names: List[str] = [] refs_patterns: List[str] = [] translation_detok_file_names: List[str] = [] translation_conf_file_names: List[str] = [] @@ -495,6 +496,7 @@ def test_checkpoint( refs_patterns.append("test.trg.detok*.txt") translation_detok_file_names.append(f"test.trg-predictions.detok.txt.{suffix_str}") translation_conf_file_names.append(f"test.trg-predictions.txt.{suffix_str}.confidences.tsv") + gold_standard_target_file_names.append("test.trg.txt") else: # test data is split into separate files for src_iso in sorted(config.test_src_isos): @@ -510,14 +512,17 @@ def test_checkpoint( refs_patterns.append(f"{prefix}.trg.detok*.txt") translation_detok_file_names.append(f"{prefix}.trg-predictions.detok.txt.{suffix_str}") translation_conf_file_names.append(f"{prefix}.trg-predictions.txt.{suffix_str}.confidences.tsv") + gold_standard_target_file_names.append(f"{prefix}.trg.txt") checkpoint_name = "averaged checkpoint" if step == -1 else f"checkpoint {step}" source_paths: List[Path] = [] + gold_standard_target_paths: List[Path] = [] vref_paths: Optional[List[Path]] = [] if config.has_scripture_data else None translation_paths: List[Path] = [] for i in range(len(translation_file_names)): predictions_path = config.exp_dir / translation_file_names[i] + gold_standard_target_paths.append(config.exp_dir / gold_standard_target_file_names[i]) if force_infer or not predictions_path.is_file(): source_paths.append(config.exp_dir / source_file_names[i]) translation_paths.append(predictions_path) @@ -527,6 +532,7 @@ def test_checkpoint( LOGGER.info(f"Inferencing {checkpoint_name}") model.translate_test_files( source_paths, + gold_standard_target_paths, translation_paths, produce_multiple_translations, save_confidences, diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index 
2d936616..48eb41d2 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -2,17 +2,16 @@ import logging import os import time -from contextlib import AbstractContextManager from dataclasses import dataclass from pathlib import Path -from typing import Generator, Iterable, List, Optional, Tuple, Union +from typing import Generator, List, Optional, Tuple, Union -from machine.scripture import VerseRef, book_number_to_id, get_chapters +from machine.scripture import book_number_to_id, get_chapters from ..common.environment import SIL_NLP_ENV from ..common.paratext import book_file_name_digits, get_project_dir from ..common.postprocesser import PostprocessConfig, PostprocessHandler -from ..common.translator import SentenceTranslationGroup, Translator +from ..common.translator import SentenceTranslationGroup, TranslationInputSentence, Translator from ..common.utils import get_git_revision_hash, show_attrs from .clearml_connection import TAGS_LIST, SILClearML from .config import CheckpointType, Config, NMTModel @@ -27,15 +26,10 @@ def __init__(self, model: NMTModel, checkpoint: Union[CheckpointType, str, int]) def translate( self, - sentences: Iterable[str], - src_iso: str, - trg_iso: str, + sentences: List[TranslationInputSentence], produce_multiple_translations: bool = False, - vrefs: Optional[Iterable[VerseRef]] = None, ) -> Generator[SentenceTranslationGroup, None, None]: - yield from self._model.translate( - sentences, src_iso, trg_iso, produce_multiple_translations, vrefs, self._checkpoint - ) + yield from self._model.translate(sentences, produce_multiple_translations, self._checkpoint) def __exit__( self, exc_type, exc_value, traceback # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]