@@ -251,26 +251,26 @@ def test_generative_model_kwargs(self):
251251 @require_torch
252252 def test_csm_model_pt (self ):
253253 speech_generator = pipeline (task = "text-to-audio" , model = "sesame/csm-1b" , device = torch_device )
254- generate_kwargs = {"max_new_tokens" : 10 , "output_audio" : True }
254+ generate_kwargs = {"max_new_tokens" : 10 }
255255 num_channels = 1 # model generates mono audio
256256
257- outputs = speech_generator ("[0] This is a test" , generate_kwargs = generate_kwargs )
257+ outputs = speech_generator ("This is a test" , generate_kwargs = generate_kwargs )
258258 self .assertEqual (outputs ["sampling_rate" ], 24000 )
259259 audio = outputs ["audio" ]
260260 self .assertEqual (ANY (np .ndarray ), audio )
261261 # ensure audio and not discrete codes
262262 self .assertEqual (len (audio .shape ), num_channels )
263263
264264 # test two examples side-by-side
265- outputs = speech_generator (["[0] This is a test" , "[0] This is a second test" ], generate_kwargs = generate_kwargs )
265+ outputs = speech_generator (["This is a test" , "This is a second test" ], generate_kwargs = generate_kwargs )
266266 audio = [output ["audio" ] for output in outputs ]
267267 self .assertEqual ([ANY (np .ndarray ), ANY (np .ndarray )], audio )
268268 self .assertEqual (len (audio [0 ].shape ), num_channels )
269269
270270 # test batching
271271 batch_size = 2
272272 outputs = speech_generator (
273- ["[0] This is a test" , "[0] This is a second test" ], generate_kwargs = generate_kwargs , batch_size = batch_size
273+ ["This is a test" , "This is a second test" ], generate_kwargs = generate_kwargs , batch_size = batch_size
274274 )
275275 self .assertEqual (len (outputs ), batch_size )
276276 audio = [output ["audio" ] for output in outputs ]
@@ -284,9 +284,7 @@ def test_dia_model(self):
284284 generate_kwargs = {"max_new_tokens" : 20 }
285285 num_channels = 1 # model generates mono audio
286286
287- outputs = speech_generator (
288- "[S1] Dia is an open weights text to dialogue model." , generate_kwargs = generate_kwargs
289- )
287+ outputs = speech_generator ("Dia is an open weights text to dialogue model." , generate_kwargs = generate_kwargs )
290288 self .assertEqual (outputs ["sampling_rate" ], 44100 )
291289 audio = outputs ["audio" ]
292290 self .assertEqual (ANY (np .ndarray ), audio )
@@ -295,7 +293,7 @@ def test_dia_model(self):
295293
296294 # test two examples side-by-side
297295 outputs = speech_generator (
298- ["[S1] Dia is an open weights text to dialogue model." , "[S2] This is a second example." ],
296+ ["Dia is an open weights text to dialogue model." , "This is a second example." ],
299297 generate_kwargs = generate_kwargs ,
300298 )
301299 audio = [output ["audio" ] for output in outputs ]
@@ -305,7 +303,7 @@ def test_dia_model(self):
305303 # test batching
306304 batch_size = 2
307305 outputs = speech_generator (
308- ["[S1] Dia is an open weights text to dialogue model." , "[S2] This is a second example." ],
306+ ["Dia is an open weights text to dialogue model." , "This is a second example." ],
309307 generate_kwargs = generate_kwargs ,
310308 batch_size = 2 ,
311309 )
0 commit comments