-
Notifications
You must be signed in to change notification settings - Fork 100
feat: add multilanguage support to text2vec and vec2text #76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -174,7 +174,7 @@ def __init__( | |
| def predict( | ||
| self, | ||
| input: Union[Path, Sequence[str]], | ||
| source_lang: str, | ||
| source_lang: Union[str, Sequence[str]], | ||
| batch_size: Optional[int] = 5, | ||
| batch_max_tokens: Optional[int] = None, | ||
| max_seq_len: Optional[int] = None, | ||
|
|
@@ -196,9 +196,24 @@ def predict( | |
| if batch_size is not None and batch_size <= 0: | ||
| raise ValueError("`batch_size` should be strictly positive") | ||
|
|
||
| tokenizer_encoder = self.tokenizer.create_encoder( | ||
| lang=source_lang, device=self.device | ||
| ) | ||
| def encode_fn(x: Union[str, tuple[str, str]]) -> torch.Tensor: | ||
| if isinstance(source_lang, str): | ||
| assert isinstance(x, str) | ||
| tokenizer_encoder = self.tokenizer.create_encoder( | ||
| lang=source_lang, | ||
| device=self.device, | ||
| ) | ||
| return tokenizer_encoder(x) | ||
| else: | ||
| # Multiple languages | ||
| assert isinstance(x, tuple) | ||
| text, lang = x | ||
| tokenizer_encoder = self.tokenizer.create_encoder( | ||
| lang=lang, | ||
| device=self.device, | ||
| ) | ||
| return tokenizer_encoder(text) | ||
|
|
||
| model_max_len = cast(int | None, self.model.encoder_frontend.pos_encoder.max_seq_len) # type: ignore[union-attr] | ||
| if max_seq_len is None: | ||
| max_seq_len = model_max_len | ||
|
|
@@ -221,15 +236,37 @@ def truncate(x: torch.Tensor) -> torch.Tensor: | |
| if isinstance(input, (str, Path)): | ||
| pipeline_builder = read_text(Path(input)) | ||
| sorting_index = None | ||
| if not isinstance(source_lang, str): | ||
| raise ValueError( | ||
| "If input is a file, source_lang must be a single string." | ||
| ) | ||
| else: | ||
| # so it should be a list | ||
| sorting_index = torch.argsort(torch.tensor(list(map(len, input)))) | ||
| pipeline_builder = read_sequence(list(sorting_index.cpu())).map( | ||
| input.__getitem__ | ||
| # input is a list | ||
| if isinstance(source_lang, str): | ||
| items = input | ||
| else: | ||
| if len(input) != len(source_lang): | ||
| raise ValueError("Length of input and source_lang must match.") | ||
| items = list(zip(input, source_lang)) # type: ignore[arg-type] | ||
| sorting_index = torch.argsort( | ||
| torch.tensor( | ||
| list( | ||
| map( | ||
| lambda x: ( | ||
| len(x[0]) | ||
| if not isinstance(source_lang, str) | ||
| else len(x) | ||
| ), | ||
| items, | ||
| ) | ||
| ) | ||
| ) | ||
| ) | ||
| sorted_items = [items[i] for i in sorting_index.tolist()] | ||
| pipeline_builder = read_sequence(sorted_items) | ||
|
|
||
| pipeline: Iterable = ( | ||
| pipeline_builder.map(tokenizer_encoder) | ||
| pipeline_builder.map(encode_fn) | ||
| .map(truncate) | ||
| .dynamic_bucket( | ||
| batch_max_tokens or 2**31, | ||
|
|
@@ -306,7 +343,7 @@ def __init__( | |
| def predict( | ||
| self, | ||
| inputs: torch.Tensor, | ||
| target_lang: str, | ||
| target_lang: Union[str, Sequence[str]], | ||
| batch_size: int = 5, | ||
| progress_bar: bool = False, | ||
| sampler: Optional[Sampler] = None, | ||
|
|
@@ -319,25 +356,41 @@ def predict( | |
| else: | ||
| generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs) | ||
|
|
||
| converter = SequenceToTextConverter( | ||
| generator, | ||
| self.tokenizer, | ||
| task="translation", | ||
| target_lang=target_lang, | ||
| ) | ||
| if isinstance(target_lang, str): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To avoid having two big separate branches of code, maybe we simply convert the case of a single language code into the case of a sequence at the very beginning, and then proceed with the same translation function?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about this?
||
| target_lang = [target_lang] * len(inputs) | ||
|
|
||
| if len(target_lang) != len(inputs): | ||
| raise ValueError( | ||
| "input and target_lang must have the same length for multi-language decoding." | ||
| ) | ||
|
|
||
| def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]: | ||
| def _do_translate(x: tuple[torch.Tensor, str]): | ||
| tensor, lang = x | ||
| converter = SequenceToTextConverter( | ||
| generator, | ||
| self.tokenizer, | ||
| task="translation", | ||
| target_lang=lang, | ||
| ) | ||
| texts, _ = converter.batch_convert( | ||
| torch.stack(src_tensors).to(self.device), None | ||
| torch.stack([tensor]).to(self.device), | ||
| None, | ||
| ) | ||
| return texts | ||
|
|
||
| pipeline: Iterable = ( | ||
| read_sequence(list(inputs)) | ||
| .bucket(batch_size) | ||
| read_sequence( | ||
| list( | ||
| zip( | ||
| list(inputs), | ||
| list(target_lang), | ||
| ) | ||
| ) | ||
| ) | ||
| .map(_do_translate) | ||
| .and_return() | ||
| ) | ||
|
|
||
| if progress_bar: | ||
| pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size) | ||
| with precision_context(self.model.dtype): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be faster if you create encoders only once per language (e.g. by caching them with `lru_cache` or something like that). Ideally, you should benchmark this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With `functools.lru_cache(maxsize=32)` and batch_size 10, the decoder took longer — how about for now we don't use a cache?