Skip to content

Commit 26dbe64

Browse files
authored
Convenient default behavior for pipeline TTS usage. (#42473)
1 parent cac0a28 commit 26dbe64

File tree

2 files changed: +18 -9 lines changed

2 files changed

+18
-9
lines changed

src/transformers/pipelines/text_to_audio.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ def preprocess(self, text, **kwargs):
172172
**kwargs,
173173
)
174174
else:
175+
# Add a default speaker ID if the model needs one and the user didn't insert one at the start of the text
176+
if self.model.config.model_type == "csm":
177+
text = [f"[0]{t}" if not t.startswith("[") else t for t in text]
178+
if self.model.config.model_type == "dia":
179+
text = [f"[S1] {t}" if not t.startswith("[") else t for t in text]
175180
output = preprocessor(text, **kwargs, return_tensors="pt")
176181

177182
return output
@@ -196,6 +201,12 @@ def _forward(self, model_inputs, **kwargs):
196201
# ensure dict output to facilitate postprocessing
197202
forward_params.update({"return_dict_in_generate": True})
198203

204+
if self.model.config.model_type in ["csm"]:
205+
# NOTE (ebezzam): CSM does not have the audio tokenizer in the processor, therefore `output_audio=True` is
206+
# needed for decoding to audio
207+
if "output_audio" not in forward_params:
208+
forward_params["output_audio"] = True
209+
199210
output = self.model.generate(**model_inputs, **forward_params)
200211
else:
201212
if len(generate_kwargs):

tests/pipelines/test_pipelines_text_to_audio.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -251,26 +251,26 @@ def test_generative_model_kwargs(self):
251251
@require_torch
252252
def test_csm_model_pt(self):
253253
speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device)
254-
generate_kwargs = {"max_new_tokens": 10, "output_audio": True}
254+
generate_kwargs = {"max_new_tokens": 10}
255255
num_channels = 1 # model generates mono audio
256256

257-
outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs)
257+
outputs = speech_generator("This is a test", generate_kwargs=generate_kwargs)
258258
self.assertEqual(outputs["sampling_rate"], 24000)
259259
audio = outputs["audio"]
260260
self.assertEqual(ANY(np.ndarray), audio)
261261
# ensure audio and not discrete codes
262262
self.assertEqual(len(audio.shape), num_channels)
263263

264264
# test two examples side-by-side
265-
outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs)
265+
outputs = speech_generator(["This is a test", "This is a second test"], generate_kwargs=generate_kwargs)
266266
audio = [output["audio"] for output in outputs]
267267
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
268268
self.assertEqual(len(audio[0].shape), num_channels)
269269

270270
# test batching
271271
batch_size = 2
272272
outputs = speech_generator(
273-
["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size
273+
["This is a test", "This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size
274274
)
275275
self.assertEqual(len(outputs), batch_size)
276276
audio = [output["audio"] for output in outputs]
@@ -284,9 +284,7 @@ def test_dia_model(self):
284284
generate_kwargs = {"max_new_tokens": 20}
285285
num_channels = 1 # model generates mono audio
286286

287-
outputs = speech_generator(
288-
"[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs
289-
)
287+
outputs = speech_generator("Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs)
290288
self.assertEqual(outputs["sampling_rate"], 44100)
291289
audio = outputs["audio"]
292290
self.assertEqual(ANY(np.ndarray), audio)
@@ -295,7 +293,7 @@ def test_dia_model(self):
295293

296294
# test two examples side-by-side
297295
outputs = speech_generator(
298-
["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."],
296+
["Dia is an open weights text to dialogue model.", "This is a second example."],
299297
generate_kwargs=generate_kwargs,
300298
)
301299
audio = [output["audio"] for output in outputs]
@@ -305,7 +303,7 @@ def test_dia_model(self):
305303
# test batching
306304
batch_size = 2
307305
outputs = speech_generator(
308-
["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."],
306+
["Dia is an open weights text to dialogue model.", "This is a second example."],
309307
generate_kwargs=generate_kwargs,
310308
batch_size=2,
311309
)

0 commit comments

Comments
 (0)