diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index be7a9b9bc0c8..9331d3365a3e 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -172,6 +172,11 @@ def preprocess(self, text, **kwargs): **kwargs, ) else: + # Add a speaker ID if the user didn't insert one at the start of the text + if self.model.config.model_type == "csm": + text = [f"[0]{t}" if not t.startswith("[") else t for t in text] + if self.model.config.model_type == "dia": + text = [f"[S1] {t}" if not t.startswith("[") else t for t in text] output = preprocessor(text, **kwargs, return_tensors="pt") return output @@ -196,6 +201,12 @@ def _forward(self, model_inputs, **kwargs): # ensure dict output to facilitate postprocessing forward_params.update({"return_dict_in_generate": True}) + if self.model.config.model_type in ["csm"]: + # NOTE (ebezzam): CSM does not have the audio tokenizer in the processor, therefore `output_audio=True` + # is needed for decoding to audio + if "output_audio" not in forward_params: + forward_params["output_audio"] = True + output = self.model.generate(**model_inputs, **forward_params) else: if len(generate_kwargs): diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index 19c7570e9e9d..8f8462b2288a 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -251,10 +251,10 @@ def test_generative_model_kwargs(self): @require_torch def test_csm_model_pt(self): speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device) - generate_kwargs = {"max_new_tokens": 10, "output_audio": True} + generate_kwargs = {"max_new_tokens": 10} num_channels = 1 # model generates mono audio - outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs) + outputs = speech_generator("This is a test", generate_kwargs=generate_kwargs) 
self.assertEqual(outputs["sampling_rate"], 24000) audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) @@ -262,7 +262,7 @@ def test_csm_model_pt(self): self.assertEqual(len(audio.shape), num_channels) # test two examples side-by-side - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs) + outputs = speech_generator(["This is a test", "This is a second test"], generate_kwargs=generate_kwargs) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) self.assertEqual(len(audio[0].shape), num_channels) @@ -270,7 +270,7 @@ def test_csm_model_pt(self): # test batching batch_size = 2 outputs = speech_generator( - ["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size + ["This is a test", "This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size ) self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] @@ -284,9 +284,7 @@ def test_dia_model(self): generate_kwargs = {"max_new_tokens": 20} num_channels = 1 # model generates mono audio - outputs = speech_generator( - "[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs - ) + outputs = speech_generator("Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs) self.assertEqual(outputs["sampling_rate"], 44100) audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) @@ -295,7 +293,7 @@ def test_dia_model(self): # test two examples side-by-side outputs = speech_generator( - ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], + ["Dia is an open weights text to dialogue model.", "This is a second example."], generate_kwargs=generate_kwargs, ) audio = [output["audio"] for output in outputs] @@ -305,7 +303,7 @@ def test_dia_model(self): # test batching batch_size = 2 outputs = speech_generator( - 
["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], + ["Dia is an open weights text to dialogue model.", "This is a second example."], generate_kwargs=generate_kwargs, batch_size=2, )