Summary
Inference fails with an IndexError when using a LoRA-fine-tuned Soprano model (merged or adapter-based) via SopranoTTS.
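This appears to be caused by an incorrect assumption about the shape of hidden_states inside infer_batch: the code transposes dimensions 1 and 2 after unsqueeze(0), but the error message indicates the tensor has only two dimensions at that point.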
Environment
- OS: Linux (Colab / Kaggle)
- Python: 3.12
- torch: 2.9.0
- transformers: 4.57.6
- peft: 0.18.1
- soprano: latest from pip
- Base model: ekwek/Soprano-1.1-80M
- Fine-tuning: LoRA (PEFT), language adaptation (non-English)
What I did
- Fine-tuned ekwek/Soprano-1.1-80M using LoRA for a different language.
- Merged LoRA weights back into the base model:
merged_model = model.merge_and_unload()
merged_model.save_pretrained("merged")
tokenizer.save_pretrained("merged")
- Copied decoder.pth from the base Soprano repo into merged/.
- Ran inference using SopranoTTS.
Error Trace
IndexError Traceback (most recent call last)
/tmp/ipython-input-2296956188.py in <cell line: 0>()
3 model = SopranoTTS(backend='auto', model_path='/content/merged-fixed', device='auto', cache_size_mb=100, decoder_batch_size=1)
4
----> 5 out = model.infer("some non enligh text")
6 print(out)
7 with open("audio.wav", 'wb+') as f:
1 frames
/usr/local/lib/python3.12/dist-packages/soprano/tts.py in infer_batch(self, texts, out_dir, top_p, temperature, repetition_penalty, retries)
181 batch_hidden_states.append(torch.cat([
182 torch.zeros((1, 512, lengths[0]-lengths[i]), device=self.device),
--> 183 hidden_states[idx+i].unsqueeze(0).transpose(1,2).to(self.device).to(torch.float32),
184 ], dim=2))
185 batch_hidden_states = torch.cat(batch_hidden_states)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
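Reading the message against line 183 of tts.py: the valid range [-2, 1] means the tensor is 2-D when transpose(1,2) is called, so hidden_states[idx+i] was most likely 1-D before unsqueeze(0). The zero padding next to it has shape (1, 512, N), which suggests a 2-D entry of shape (T, 512) was expected. Below is a minimal pure-Python sketch of that reasoning using shape tuples only. The helpers check_dim, unsqueeze0 and transposed_shape are hypothetical names for illustration and are not part of soprano or torch; the shapes (7, 512) and (512,) are assumed examples, not values I measured.

```python
def check_dim(dim: int, ndim: int) -> int:
    """Mimic the dimension range check applied by ops such as transpose.

    Hypothetical helper; assumes ndim >= 1.
    """
    low, high = -ndim, ndim - 1
    if not low <= dim <= high:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{low}, {high}], but got {dim})"
        )
    return dim % ndim  # map negative dims to their non-negative index


def unsqueeze0(shape: tuple) -> tuple:
    """Shape after adding a leading axis of size 1, like unsqueeze(0)."""
    return (1, *shape)


def transposed_shape(shape: tuple, dim0: int, dim1: int) -> tuple:
    """Shape after swapping two axes, like transpose(dim0, dim1)."""
    d0 = check_dim(dim0, len(shape))
    d1 = check_dim(dim1, len(shape))
    out = list(shape)
    out[d0], out[d1] = out[d1], out[d0]
    return tuple(out)


# Assumed "expected" case: a per-sample hidden state of shape (T, 512).
ok_shape = transposed_shape(unsqueeze0((7, 512)), 1, 2)

# Assumed failing case: a 1-D hidden state of shape (512,).
try:
    transposed_shape(unsqueeze0((512,)), 1, 2)
    failure = None
except IndexError as exc:
    failure = str(exc)

print(ok_shape)
print(failure)
```

If this reading is right, printing hidden_states[idx+i].shape just before line 183 should show a single dimension for the fine-tuned model and two dimensions for the base model.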