diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1edcfe8..18820b3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,38 +17,47 @@ jobs: pytorch-version: 1.10.1 numpy-requirement: "'numpy<2'" tokenizers-requirement: "'tokenizers<=0.20.3'" + transformers-requirement: "'transformers==4.46.3'" - python-version: '3.8' pytorch-version: 1.13.1 numpy-requirement: "'numpy<2'" tokenizers-requirement: "'tokenizers<=0.20.3'" + transformers-requirement: "'transformers==4.46.3'" - python-version: '3.8' pytorch-version: 2.0.1 numpy-requirement: "'numpy<2'" tokenizers-requirement: "'tokenizers<=0.20.3'" + transformers-requirement: "'transformers==4.46.3'" - python-version: '3.9' pytorch-version: 2.1.2 numpy-requirement: "'numpy<2'" tokenizers-requirement: "'tokenizers'" + transformers-requirement: "'transformers'" - python-version: '3.10' pytorch-version: 2.2.2 numpy-requirement: "'numpy<2'" tokenizers-requirement: "'tokenizers'" + transformers-requirement: "'transformers'" - python-version: '3.11' pytorch-version: 2.3.1 numpy-requirement: "'numpy'" tokenizers-requirement: "'tokenizers'" + transformers-requirement: "'transformers'" - python-version: '3.12' pytorch-version: 2.4.1 numpy-requirement: "'numpy'" tokenizers-requirement: "'tokenizers'" + transformers-requirement: "'transformers'" - python-version: '3.12' pytorch-version: 2.5.0 numpy-requirement: "'numpy'" tokenizers-requirement: "'tokenizers'" + transformers-requirement: "'transformers'" - python-version: '3.12' pytorch-version: 2.6.0 numpy-requirement: "'numpy'" tokenizers-requirement: "'tokenizers'" + transformers-requirement: "'transformers'" steps: - uses: conda-incubator/setup-miniconda@v3 - run: conda install -n test ffmpeg python=${{ matrix.python-version }} @@ -63,7 +72,7 @@ jobs: - run: python test/test_transcribe.py load_faster_whisper - run: python test/test_align.py load_faster_whisper - run: python test/test_refine.py load_faster_whisper - - run: pip3 
install .["hf"] 'transformers<=4.46.3' + - run: pip3 install .["hf"] ${{ matrix.transformers-requirement }} - run: python test/test_transcribe.py load_hf_whisper - run: python test/test_align.py load_hf_whisper - run: python test/test_refine.py load_hf_whisper diff --git a/setup.py b/setup.py index 30b9f51..afd6f4b 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def read_me() -> str: "torch", "torchaudio", "tqdm", - "openai-whisper>=20230314,<=20240930" + "openai-whisper>=20250625" ], extras_require={ "fw": [ diff --git a/stable_whisper/whisper_compatibility.py b/stable_whisper/whisper_compatibility.py index 8cd1118..ca9efa9 100644 --- a/stable_whisper/whisper_compatibility.py +++ b/stable_whisper/whisper_compatibility.py @@ -16,6 +16,7 @@ '20231117', '20240927', '20240930', + '20250625', ) _required_whisper_ver = _COMPATIBLE_WHISPER_VERSIONS[-1] diff --git a/stable_whisper/whisper_word_level/hf_whisper.py b/stable_whisper/whisper_word_level/hf_whisper.py index 8de045e..d00671c 100644 --- a/stable_whisper/whisper_word_level/hf_whisper.py +++ b/stable_whisper/whisper_word_level/hf_whisper.py @@ -63,38 +63,39 @@ def get_device(device: str = None) -> str: def load_hf_pipe(model_name: str, device: str = None, flash: bool = False, **pipeline_kwargs): from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline + from transformers.configuration_utils import PretrainedConfig device = get_device(device) is_cpu = (device if isinstance(device, str) else getattr(device, 'type', None)) == 'cpu' dtype = torch.float32 if is_cpu or not torch.cuda.is_available() else torch.float16 model_id = HF_MODELS.get(model_name, model_name) + + if flash: + config = PretrainedConfig( + attn_implementation="flash_attention_2", + ) + else: + config = None + model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, torch_dtype=dtype, low_cpu_mem_usage=True, use_safetensors=True, - use_flash_attention_2=flash + config=config ).to(device) processor = 
AutoProcessor.from_pretrained(model_id) - if not flash: - try: - model = model.to_bettertransformer() - except (ValueError, ImportError) as e: - import warnings - warnings.warn( - f'Failed convert model to BetterTransformer due to: {e}' - ) - final_pipe_kwargs = dict( task="automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, max_new_tokens=128, - chunk_length_s=30, + # chunk_length_s=30, torch_dtype=dtype, device=device, + return_language=True ) final_pipe_kwargs.update(**pipeline_kwargs) pipe = pipeline(**final_pipe_kwargs) @@ -106,6 +107,7 @@ class WhisperHF: def __init__(self, model_name: str, device: str = None, flash: bool = False, pipeline=None, **pipeline_kwargs): self._model_name = model_name + pipeline_kwargs['return_language'] = True self._pipe = load_hf_pipe(self._model_name, device, flash=flash, **pipeline_kwargs) if pipeline is None \ else pipeline self._model_name = getattr(self._pipe.model, 'name_or_path', self._model_name) @@ -154,6 +156,45 @@ def _inner_transcribe( language = 'en' if not language and result and 'language' in result[0]: language = result[0]['language'] + if not language and hasattr(output, 'get') and 'detected_language' in output: + language = output['detected_language'] + if not language: + # HF Pipelines have broken language detection. + # Manually detect language by generating tokens from the first 10 seconds of the audio. 
+ try: + import torch + sample_audio = audio[:int(self.sampling_rate * 10)] # Use first 10 seconds + inputs = self._pipe.feature_extractor(sample_audio, sampling_rate=self.sampling_rate, return_tensors="pt") + + # Ensure input features match model dtype and device + model_dtype = next(self._pipe.model.parameters()).dtype + model_device = next(self._pipe.model.parameters()).device + inputs.input_features = inputs.input_features.to(dtype=model_dtype, device=model_device) + + # Generate with minimal tokens to detect language + with torch.no_grad(): + generated_ids = self._pipe.model.generate( + inputs.input_features, + max_new_tokens=10, + do_sample=False, + output_scores=True, + return_dict_in_generate=True + ) + + # Decode the tokens to extract language information + tokens = self._pipe.tokenizer.batch_decode(generated_ids.sequences, skip_special_tokens=False)[0] + + # Extract language token (format: <|en|>, <|fr|>, etc.; codes may be 2 or 3 letters, e.g. <|yue|>) + import re + lang_match = re.search(r'<\|([a-z]{2,3})\|>', tokens) + if lang_match: + language = lang_match.group(1) + else: + language = None + + except Exception as e: + print(f'Error detecting language: {e}') + language = None if verbose is not None: print(f'Transcription completed.')