Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions sonar/inference_pipelines/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
def predict(
self,
input: Union[Path, Sequence[str]],
source_lang: str,
source_lang: Union[str, Sequence[str]],
batch_size: Optional[int] = 5,
batch_max_tokens: Optional[int] = None,
max_seq_len: Optional[int] = None,
Expand All @@ -196,9 +196,24 @@ def predict(
if batch_size is not None and batch_size <= 0:
raise ValueError("`batch_size` should be strictly positive")

tokenizer_encoder = self.tokenizer.create_encoder(
lang=source_lang, device=self.device
)
def encode_fn(x: Union[str, tuple[str, str]]) -> torch.Tensor:
if isinstance(source_lang, str):
assert isinstance(x, str)
tokenizer_encoder = self.tokenizer.create_encoder(
lang=source_lang,
device=self.device,
)
return tokenizer_encoder(x)
else:
# Multiple languages
assert isinstance(x, tuple)
text, lang = x
tokenizer_encoder = self.tokenizer.create_encoder(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be faster if you create encoders only once per language (e.g. by caching them with lru_cache or something like that).
Ideally, you should benchmark this.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with functools.lru_cache(maxsize=32), batch_size 10, the decoder took longer, how about for now we don't use cache?

image

lang=lang,
device=self.device,
)
return tokenizer_encoder(text)

model_max_len = cast(int | None, self.model.encoder_frontend.pos_encoder.max_seq_len) # type: ignore[union-attr]
if max_seq_len is None:
max_seq_len = model_max_len
Expand All @@ -221,15 +236,37 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
if isinstance(input, (str, Path)):
pipeline_builder = read_text(Path(input))
sorting_index = None
if not isinstance(source_lang, str):
raise ValueError(
"If input is a file, source_lang must be a single string."
)
else:
# so it should be a list
sorting_index = torch.argsort(torch.tensor(list(map(len, input))))
pipeline_builder = read_sequence(list(sorting_index.cpu())).map(
input.__getitem__
# input is a list
if isinstance(source_lang, str):
items = input
else:
if len(input) != len(source_lang):
raise ValueError("Length of input and source_lang must match.")
items = list(zip(input, source_lang)) # type: ignore[arg-type]
sorting_index = torch.argsort(
torch.tensor(
list(
map(
lambda x: (
len(x[0])
if not isinstance(source_lang, str)
else len(x)
),
items,
)
)
)
)
sorted_items = [items[i] for i in sorting_index.tolist()]
pipeline_builder = read_sequence(sorted_items)

pipeline: Iterable = (
pipeline_builder.map(tokenizer_encoder)
pipeline_builder.map(encode_fn)
.map(truncate)
.dynamic_bucket(
batch_max_tokens or 2**31,
Expand Down Expand Up @@ -306,7 +343,7 @@ def __init__(
def predict(
self,
inputs: torch.Tensor,
target_lang: str,
target_lang: Union[str, Sequence[str]],
batch_size: int = 5,
progress_bar: bool = False,
sampler: Optional[Sampler] = None,
Expand All @@ -319,25 +356,41 @@ def predict(
else:
generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)

converter = SequenceToTextConverter(
generator,
self.tokenizer,
task="translation",
target_lang=target_lang,
)
if isinstance(target_lang, str):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid having two big separate branches of code, maybe we simply convert the case of single language code into the case of sequence in the very beginning, and then proceed with the same translation function?

Copy link
Copy Markdown
Author

@jasonrichdarmawan jasonrichdarmawan Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about this?

target_lang = [target_lang] * len(inputs)

if len(target_lang) != len(inputs):
raise ValueError(
"input and target_lang must have the same length for multi-language decoding."
)

def _do_translate(src_tensors: List[torch.Tensor]) -> List[str]:
def _do_translate(x: tuple[torch.Tensor, str]):
tensor, lang = x
converter = SequenceToTextConverter(
generator,
self.tokenizer,
task="translation",
target_lang=lang,
)
texts, _ = converter.batch_convert(
torch.stack(src_tensors).to(self.device), None
torch.stack([tensor]).to(self.device),
None,
)
return texts

pipeline: Iterable = (
read_sequence(list(inputs))
.bucket(batch_size)
read_sequence(
list(
zip(
list(inputs),
list(target_lang),
)
)
)
.map(_do_translate)
.and_return()
)

if progress_bar:
pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)
with precision_context(self.model.dtype):
Expand Down