diff --git a/offline_inference.py b/offline_inference.py
index b5dd708..861a310 100644
--- a/offline_inference.py
+++ b/offline_inference.py
@@ -8,12 +8,16 @@ def main():
     parser.add_argument(
         "--model-path", type=str, required=True, help="Base path for model files"
     )
+    parser.add_argument(
+        "--load-in-8bit", action="store_true", help="Load the LLM in 8-bit with bitsandbytes"
+    )
     args = parser.parse_args()
 
     model = StepAudio(
         tokenizer_path=f"{args.model_path}/Step-Audio-Tokenizer",
         tts_path=f"{args.model_path}/Step-Audio-TTS-3B",
         llm_path=f"{args.model_path}/Step-Audio-Chat",
+        load_in_8bit=args.load_in_8bit,
     )
 
     # example for text input
diff --git a/requirements-vllm.txt b/requirements-vllm.txt
index 78b53c2..0961b68 100644
--- a/requirements-vllm.txt
+++ b/requirements-vllm.txt
@@ -19,4 +19,5 @@ sentencepiece
 funasr>=1.1.3
 protobuf==5.29.3
 gradio>=5.16.0
-vllm==0.7.2
\ No newline at end of file
+vllm==0.7.2
+bitsandbytes
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 2aa7689..2b13cd2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ sentencepiece
 funasr>=1.1.3
 protobuf==5.29.3
 gradio>=5.16.0
+bitsandbytes
\ No newline at end of file
diff --git a/stepaudio.py b/stepaudio.py
index bf80b83..36e7ed4 100644
--- a/stepaudio.py
+++ b/stepaudio.py
@@ -2,7 +2,7 @@
 
 import torch
 import torchaudio
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 from tokenizer import StepAudioTokenizer
 from tts import StepAudioTTS
@@ -10,10 +10,14 @@
 
 
 class StepAudio:
-    def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
+    def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str, load_in_8bit: bool = False):
         load_optimus_ths_lib(os.path.join(llm_path, 'lib'))
+        q_config = None
+        if load_in_8bit:
+            q_config = BitsAndBytesConfig(load_in_8bit=True)
+            print("Loading LLM in 8-bit")
         self.llm_tokenizer = AutoTokenizer.from_pretrained(
             llm_path, trust_remote_code=True
         )
         self.encoder = StepAudioTokenizer(tokenizer_path)
         self.decoder = StepAudioTTS(tts_path, self.encoder)
@@ -22,6 +26,7 @@ def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
             torch_dtype=torch.bfloat16,
             device_map="auto",
             trust_remote_code=True,
+            quantization_config=q_config,
         )
 
     def __call__(
@@ -73,26 +78,3 @@ def apply_chat_template(self, messages: list):
 
         text_with_audio += "<|BOT|>assistant\n"
         return text_with_audio
-
-if __name__ == "__main__":
-    model = StepAudio(
-        encoder_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
-        decoder_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
-        llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
-    )
-
-    text, audio, sr = model(
-        [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
-        "Tingting",
-    )
-    torchaudio.save("output/output_e2e_tqta.wav", audio, sr)
-    text, audio, sr = model(
-        [
-            {
-                "role": "user",
-                "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
-            }
-        ],
-        "Tingting",
-    )
-    torchaudio.save("output/output_e2e_aqta.wav", audio, sr)
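Usage: 8-bit loading stays opt-in, e.g. `python offline_inference.py --model-path <models_dir> --load-in-8bit`. The quantized load follows the standard `transformers` + `bitsandbytes` pattern; below is a minimal, self-contained sketch of what `StepAudio.__init__` now does for the LLM. It assumes a CUDA GPU with `bitsandbytes` installed, and the path and the `load_in_8bit` variable are placeholders for illustration.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

llm_path = "/path/to/Step-Audio-Chat"  # placeholder model path
load_in_8bit = True                    # mirrors the new CLI flag

# Only build a quantization config when the flag is set; otherwise the
# model loads in full bf16 precision as before.
q_config = BitsAndBytesConfig(load_in_8bit=True) if load_in_8bit else None

# The tokenizer is unaffected by quantization; only the model takes q_config.
tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    llm_path,
    torch_dtype=torch.bfloat16,    # dtype for the non-quantized modules
    device_map="auto",
    trust_remote_code=True,
    quantization_config=q_config,  # None disables quantization entirely
)
```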