11 changes: 11 additions & 0 deletions src/transformers/pipelines/text_to_audio.py
@@ -172,6 +172,11 @@ def preprocess(self, text, **kwargs):
                 **kwargs,
             )
         else:
+            # Add a default speaker ID if the user didn't insert one at the start of the text
+            if self.model.config.model_type == "csm":
+                text = [f"[0]{t}" if not t.startswith("[") else t for t in text]
+            if self.model.config.model_type == "dia":
+                text = [f"[S1] {t}" if not t.startswith("[") else t for t in text]
Comment on lines +175 to +179
Contributor

Hmm, really not a fan of such hidden processing. This is where the pipeline abstraction (which does make sense if you want to swap the model ID without changing any other code) complicates things more than it simplifies them... but okay to keep here, since there is already so much custom processing in the audio pipeline code anyway.

Note we might remove this in the future though, if we find a good API for model-specific kwargs for each TTS model and a convenient way to default them.

Contributor Author

Definitely, for example a preset as we discussed here:
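
A purely hypothetical sketch of what that could look like (the `TTS_PRESETS` table and `apply_preset` helper below are illustrative, not existing transformers API):

```python
# Hypothetical: per-model defaults declared in one table instead of being
# hard-coded inside preprocess/_forward.
TTS_PRESETS = {
    "csm": {"speaker_prefix": "[0]", "forward_defaults": {"output_audio": True}},
    "dia": {"speaker_prefix": "[S1] ", "forward_defaults": {}},
}


def apply_preset(text: str, model_type: str, forward_params: dict) -> tuple[str, dict]:
    """Prepend the default speaker tag and fill in default forward kwargs."""
    preset = TTS_PRESETS.get(model_type, {})
    prefix = preset.get("speaker_prefix")
    if prefix and not text.startswith("["):
        text = f"{prefix}{text}"
    # user-supplied kwargs keep priority; the preset only fills in gaps
    for key, value in preset.get("forward_defaults", {}).items():
        forward_params.setdefault(key, value)
    return text, forward_params
```

That would keep the per-model defaults in one discoverable place instead of scattered across `preprocess` and `_forward`.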

             output = preprocessor(text, **kwargs, return_tensors="pt")
 
         return output
@@ -196,6 +201,12 @@ def _forward(self, model_inputs, **kwargs):
             # ensure dict output to facilitate postprocessing
             forward_params.update({"return_dict_in_generate": True})
 
+            if self.model.config.model_type in ["csm"]:
+                # NOTE (ebezzam): CSM does not have the audio tokenizer in its processor, therefore `output_audio=True`
+                # is needed for decoding to audio
+                if "output_audio" not in forward_params:
+                    forward_params["output_audio"] = True
+
             output = self.model.generate(**model_inputs, **forward_params)
         else:
             if len(generate_kwargs):
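
Taken together, the two hunks above let callers pass plain text to the pipeline. A minimal usage sketch of the new default behavior (model ID and expected values taken from the updated tests below):

```python
from transformers import pipeline

# The pipeline now prepends the "[0]" speaker tag and defaults
# output_audio=True for CSM, so no manual prefixing is needed.
speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b")
outputs = speech_generator("This is a test", generate_kwargs={"max_new_tokens": 10})

print(outputs["sampling_rate"])  # 24000 for CSM
print(outputs["audio"].shape)    # 1-D numpy array, i.e. mono audio
```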
16 changes: 7 additions & 9 deletions tests/pipelines/test_pipelines_text_to_audio.py
@@ -251,26 +251,26 @@ def test_generative_model_kwargs(self):
     @require_torch
     def test_csm_model_pt(self):
         speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device)
-        generate_kwargs = {"max_new_tokens": 10, "output_audio": True}
+        generate_kwargs = {"max_new_tokens": 10}
         num_channels = 1  # model generates mono audio
 
-        outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs)
+        outputs = speech_generator("This is a test", generate_kwargs=generate_kwargs)
         self.assertEqual(outputs["sampling_rate"], 24000)
         audio = outputs["audio"]
         self.assertEqual(ANY(np.ndarray), audio)
         # ensure audio and not discrete codes
         self.assertEqual(len(audio.shape), num_channels)
 
         # test two examples side-by-side
-        outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs)
+        outputs = speech_generator(["This is a test", "This is a second test"], generate_kwargs=generate_kwargs)
         audio = [output["audio"] for output in outputs]
         self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
         self.assertEqual(len(audio[0].shape), num_channels)
 
         # test batching
         batch_size = 2
         outputs = speech_generator(
-            ["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size
+            ["This is a test", "This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size
         )
         self.assertEqual(len(outputs), batch_size)
         audio = [output["audio"] for output in outputs]
@@ -284,9 +284,7 @@ def test_dia_model(self):
         generate_kwargs = {"max_new_tokens": 20}
         num_channels = 1  # model generates mono audio
 
-        outputs = speech_generator(
-            "[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs
-        )
+        outputs = speech_generator("Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs)
         self.assertEqual(outputs["sampling_rate"], 44100)
         audio = outputs["audio"]
         self.assertEqual(ANY(np.ndarray), audio)
@@ -295,7 +293,7 @@ def test_dia_model(self):
 
         # test two examples side-by-side
         outputs = speech_generator(
-            ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."],
+            ["Dia is an open weights text to dialogue model.", "This is a second example."],
             generate_kwargs=generate_kwargs,
         )
         audio = [output["audio"] for output in outputs]
@@ -305,7 +303,7 @@ def test_dia_model(self):
         # test batching
         batch_size = 2
         outputs = speech_generator(
-            ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."],
+            ["Dia is an open weights text to dialogue model.", "This is a second example."],
             generate_kwargs=generate_kwargs,
             batch_size=2,
         )