From 6a7c815ba5401eace9edc7651fcaa953b601ad57 Mon Sep 17 00:00:00 2001
From: Avihu Dekel
Date: Sun, 29 Jun 2025 15:47:05 +0000
Subject: [PATCH 1/5] finetuning granite speech, initial commit

---
 notebooks/en/_toctree.yml                     |   2 +
 notebooks/en/fine_tuning_granite_speech.ipynb | 560 ++++++++++++++++++
 notebooks/en/index.md                         |   2 +-
 3 files changed, 563 insertions(+), 1 deletion(-)
 create mode 100644 notebooks/en/fine_tuning_granite_speech.ipynb

diff --git a/notebooks/en/_toctree.yml b/notebooks/en/_toctree.yml
index e46efc44..68ae24b8 100644
--- a/notebooks/en/_toctree.yml
+++ b/notebooks/en/_toctree.yml
@@ -124,6 +124,8 @@
       title: Structured Generation from Images or Documents Using Vision Language Models
     - local: fine_tuning_granite_vision_sft_trl
       title: Fine-tuning Granite Vision with TRL
+    - local: fine_tuning_granite_speech
+      title: Fine-tuning Granite Speech
 
 - title: Search Recipes
   isExpanded: false
diff --git a/notebooks/en/fine_tuning_granite_speech.ipynb b/notebooks/en/fine_tuning_granite_speech.ipynb
new file mode 100644
index 00000000..8a0bddd7
--- /dev/null
+++ b/notebooks/en/fine_tuning_granite_speech.ipynb
@@ -0,0 +1,560 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b052096b",
+   "metadata": {},
+   "source": [
+    "Written by [Avihu Dekel](https://huggingface.co/Avihu).\n",
+    "\n",
+    "# Finetuning Granite Speech\n",
+    "\n",
+    "[Granite speech](https://huggingface.co/collections/ibm-granite/granite-speech-67e45da088d5092ff6b901c7) is a family of powerful speech models that excel in speech recognition and speech translation. \n",
+    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (in June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it is easy to finetune on unseen data or add new tasks.\n",
+    "\n",
+    "In this example, we'll show how to:\n",
+    "1. Run inference with Granite Speech\n",
+    "2. Evaluate the predictions\n",
+    "3. Finetune the model with new data.\n",
+    "Specifically, we'll finetune Granite Speech 2B on [GigaSpeech](https://huggingface.co/datasets/speechcolab/gigaspeech), a large spontaneous conversational dataset which was not included in the model's training. \n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9f59dffb",
+   "metadata": {},
+   "source": [
+    "## Installing packages\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "d985dae1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# todo!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d9a0fe29",
+   "metadata": {},
+   "source": [
+    "## Dataset loading and preprocessing\n",
+    "We'll start by downloading the data. \n",
+    "We selected the smallest subset of GigaSpeech, and trimmed the train/val/test sets to be extremely small."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "8f1bc8c0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/proj/mmllm/miniforge/envs/mma/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\"AS THEY'RE LEAVING <COMMA> CAN KASH PULL ZAHRA ASIDE REALLY QUICKLY <QUESTIONMARK>\""
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from datasets import load_dataset, Audio\n",
+    "# loading small portions for speed\n",
+    "dataset = load_dataset(\"speechcolab/gigaspeech\", \"xs\")\n",
+    "train_dataset = dataset[\"train\"].take(5000)\n",
+    "val_dataset = dataset[\"validation\"].take(200)\n",
+    "test_dataset = dataset[\"test\"].take(200)\n",
+    "\n",
+    "train_dataset[0][\"text\"]"
+   ]
+  },
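+  {
+   "cell_type": "markdown",
+   "id": "a1b2c3d4",
+   "metadata": {},
+   "source": [
+    "The `audio` column holds the raw waveform together with its sampling rate. As a quick sanity check (a small sketch, not part of the original run), let's print the split sizes and the duration of the first training clip:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b2c3d4e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split sizes, plus the duration (in seconds) of the first training clip\n",
+    "print(len(train_dataset), len(val_dataset), len(test_dataset))\n",
+    "audio = train_dataset[0][\"audio\"]\n",
+    "print(f\"{len(audio['array']) / audio['sampling_rate']:.2f} seconds\")"
+   ]
+  },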
+  {
+   "cell_type": "markdown",
+   "id": "4381a33a",
+   "metadata": {},
+   "source": [
+    "## Loading the model and processor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "df6b3391",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 21.64it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from transformers.models.granite_speech import GraniteSpeechForConditionalGeneration, GraniteSpeechProcessor\n",
+    "model_name = \"ibm-granite/granite-speech-3.3-2b\"\n",
+    "processor = GraniteSpeechProcessor.from_pretrained(model_name)\n",
+    "model = GraniteSpeechForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n"
+   ]
+  },
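+  {
+   "cell_type": "markdown",
+   "id": "c3d4e5f6",
+   "metadata": {},
+   "source": [
+    "Before setting up batched evaluation, we can transcribe a single clip as a smoke test. This is a minimal sketch: it uses the same instruction prompt and processor call as the collator we'll define later, greedy decoding instead of beam search, and relies on GigaSpeech audio already matching the processor's 16 kHz sampling rate.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d4e5f6a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# transcribe one clip: render the chat template (with the <|audio|> token) and generate\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "model = model.to(device)\n",
+    "wav = train_dataset[0][\"audio\"][\"array\"]\n",
+    "chat = [dict(role=\"user\", content=\"Please transcribe the following audio to text<|audio|>\")]\n",
+    "prompt = processor.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)\n",
+    "inputs = processor([prompt], [wav], return_tensors=\"pt\").to(device)\n",
+    "with torch.inference_mode(), torch.amp.autocast(device, dtype=torch.bfloat16):\n",
+    "    out = model.generate(**inputs, max_new_tokens=200)\n",
+    "# strip the prompt tokens before decoding\n",
+    "print(processor.tokenizer.decode(out[0, inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True))"
+   ]
+  },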
+  {
+   "cell_type": "markdown",
+   "id": "aa742674",
+   "metadata": {},
+   "source": [
+    "## Data preprocessing\n",
+    "Let's continue with data processing:\n",
+    "- The text format requires some preprocessing (e.g. replace `<COMMA>` with `,`).\n",
+    "- Add an instruction prompt (e.g. `Can you transcribe the following speech<|audio|>?`)\n",
+    "- Filter non-verbal examples"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "56d43b23",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def process_gigaspeech_transcript(text):\n",
+    "    # map GigaSpeech punctuation tags to actual punctuation\n",
+    "    text = text.replace(\" <COMMA>\", \",\")\n",
+    "    text = text.replace(\" <PERIOD>\", \".\")\n",
+    "    text = text.replace(\" <QUESTIONMARK>\", \"?\")\n",
+    "    text = text.replace(\" <EXCLAMATIONPOINT>\", \"!\")\n",
+    "    text = text.lower()\n",
+    "    return text\n",
+    "\n",
+    "def prep_example(example, tokenizer):\n",
+    "    instruction = \"Please transcribe the following audio to text<|audio|>\"\n",
+    "    chat = [dict(role=\"user\", content=instruction)]\n",
+    "    example[\"prompt\"] = tokenizer.apply_chat_template(\n",
+    "        chat,\n",
+    "        add_generation_prompt=True,\n",
+    "        tokenize=False,\n",
+    "    )\n",
+    "    example[\"text\"] = process_gigaspeech_transcript(example[\"text\"])\n",
+    "    return example\n",
+    "\n",
+    "def prepare_dataset(ds, processor):\n",
+    "    columns_to_remove = [col for col in ds.column_names if col not in [\"audio\", \"text\"]]\n",
+    "    ds = ds.cast_column(\"audio\", Audio(sampling_rate=processor.audio_processor.sampling_rate))\n",
+    "    ds = ds.map(prep_example,\n",
+    "            fn_kwargs=dict(tokenizer=processor.tokenizer),\n",
+    "            remove_columns=columns_to_remove,\n",
+    "            )\n",
+    "    # drop non-verbal utterances (GigaSpeech garbage-utterance tags)\n",
+    "    ds = ds.filter(lambda x: x[\"text\"] not in [\"<SIL>\", \"<MUSIC>\", \"<NOISE>\", \"<OTHER>\"])\n",
+    "    return ds\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "fee81a17",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset = prepare_dataset(train_dataset, processor)\n",
+    "val_dataset = prepare_dataset(val_dataset, processor)\n",
+    "test_dataset = prepare_dataset(test_dataset, processor)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c27eaee0",
+   "metadata": {},
+   "source": [
+    "Let's look at a post-processed example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "f8841c7c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "as they're leaving, can kash pull zahra aside really quickly?\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "    \n",
+       "    "
+      ],
+      "text/plain": [
+       "<IPython.lib.display.Audio object>"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from IPython.display import Audio  # note: shadows datasets.Audio, which is no longer needed here\n",
+    "print(train_dataset[0][\"text\"])\n",
+    "Audio(data=train_dataset[0][\"audio\"][\"array\"], rate=train_dataset[0][\"audio\"][\"sampling_rate\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2c0cd90e",
+   "metadata": {},
+   "source": [
+    "## Running inference + WER computation\n",
+    "Now let's compute the word error rate (WER). For that we'll need to define a collator, which will also be used for finetuning."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "e4767018",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import evaluate\n",
+    "from whisper.normalizers import EnglishTextNormalizer\n",
+    "from transformers.feature_extraction_utils import BatchFeature\n",
+    "from torch.utils.data import DataLoader\n",
+    "import tqdm\n",
+    "\n",
+    "class GraniteCollator:\n",
+    "    def __init__(self, processor, inference_mode=False):\n",
+    "        self.processor = processor\n",
+    "        self.inference_mode = inference_mode\n",
+    "\n",
+    "    def __call__(self, examples):\n",
+    "        prompts = [example[\"prompt\"] for example in examples]\n",
+    "        audios = [example[\"audio\"] for example in examples]\n",
+    "        if isinstance(audios[0], dict):\n",
+    "            audios = [audio[\"array\"] for audio in audios]\n",
+    "\n",
+    "        processed = self.processor(prompts, audios, return_tensors=\"pt\", padding=True, padding_side=\"left\")\n",
+    "        input_ids = processed.input_ids\n",
+    "        attention_mask = processed.attention_mask\n",
+    "        labels = None\n",
+    "        # tokenize targets\n",
+    "        if not self.inference_mode:\n",
+    "            targets = [example[\"text\"] + self.processor.tokenizer.eos_token for example in examples]\n",
+    "            targets = self.processor.tokenizer(targets, return_tensors=\"pt\", padding=True, padding_side=\"right\")\n",
+    "            # combine prompt+targets\n",
+    "            input_ids = torch.cat([input_ids, targets.input_ids], dim=1)\n",
+    "            attention_mask = torch.cat([attention_mask, targets.attention_mask], dim=1)\n",
+    "            labels = targets.input_ids.clone()\n",
+    "            # Set non-target tokens to -100 for loss calculation\n",
+    "            labels[~(targets.attention_mask.bool())] = -100 \n",
+    "            labels = torch.cat([torch.full_like(processed.input_ids, -100), labels], dim=1)\n",
+    "\n",
+    "        return BatchFeature(data={\n",
+    "            \"input_ids\": input_ids,\n",
+    "            \"attention_mask\": attention_mask,\n",
+    "            \"labels\": labels,\n",
+    "            \"input_features\": processed.input_features,\n",
+    "            \"input_features_mask\": processed.input_features_mask\n",
+    "        })\n",
+    "\n",
+    "def compute_wer(model, processor, cur_dataset):\n",
+    "    collator = GraniteCollator(processor, inference_mode=True)\n",
+    "    dataloader = DataLoader(cur_dataset, batch_size=16, collate_fn=collator, num_workers=8)\n",
+    "    normalizer = EnglishTextNormalizer()\n",
+    "    wer_metric = evaluate.load(\"wer\")\n",
+    "    model = model.eval().cuda()\n",
+    "    \n",
+    "    all_outputs = []\n",
+    "    for batch in tqdm.tqdm(dataloader, desc=\"Running inference\"):\n",
+    "        batch = batch.to(\"cuda\")\n",
+    "        with torch.inference_mode(), torch.amp.autocast(\"cuda\", dtype=torch.bfloat16):\n",
+    "            outputs = model.generate(**batch, max_new_tokens=400, num_beams=4, early_stopping=True)\n",
+    "        # keep only the newly generated tokens, dropping the prompt\n",
+    "        input_length = batch.input_ids.shape[1]\n",
+    "        outputs = outputs[:, input_length:].cpu()\n",
+    "        for x in outputs:\n",
+    "            all_outputs.append(processor.tokenizer.decode(x, skip_special_tokens=True))\n",
+    "    \n",
+    "    gt_texts = [normalizer(x) for x in cur_dataset[\"text\"]]\n",
+    "    all_outputs = [normalizer(x) for x in all_outputs]\n",
+    "    wer = wer_metric.compute(references=gt_texts, predictions=all_outputs)\n",
+    "    return wer\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "2654b4aa",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Running inference: 100%|████████████████████████████████████████████████████████████| 10/10 [00:24<00:00,  2.41s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "WER before finetuning 9.719\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "wer_before_train = compute_wer(model, processor, test_dataset)\n",
+    "print(f\"WER before finetuning {wer_before_train*100:.3f}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "86db15df",
+   "metadata": {},
+   "source": [
+    "## Finetuning Granite Speech\n",
+    "Let's finetune the model on our small training set.\n",
+    "We'll only tune the LoRA adapters and the projector, to speed up training and avoid overfitting.\n"
+   ]
+  },
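+  {
+   "cell_type": "markdown",
+   "id": "e5f6a7b8",
+   "metadata": {},
+   "source": [
+    "To see how lightweight this is, we can count the parameters that the projector/LoRA name rule in the training cell below will actually update (a quick sketch using the same matching rule):\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f6a7b8c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# count trainable vs. total parameters under the projector/LoRA rule\n",
+    "trainable = sum(p.numel() for n, p in model.named_parameters() if \"projector\" in n or \"lora\" in n)\n",
+    "total = sum(p.numel() for p in model.parameters())\n",
+    "print(f\"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)\")"
+   ]
+  },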
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "539554a8",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "    <div>\n",
+       "      \n",
+       "      <progress value='157' max='157' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+       "      [157/157 01:38, Epoch 1/1]\n",
+       "    </div>\n",
+       "    <table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: left;\">\n",
+       "      <th>Step</th>\n",
+       "      <th>Training Loss</th>\n",
+       "      <th>Validation Loss</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <td>16</td>\n",
+       "      <td>1.747400</td>\n",
+       "      <td>0.841541</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>32</td>\n",
+       "      <td>1.360500</td>\n",
+       "      <td>0.657465</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>48</td>\n",
+       "      <td>0.865000</td>\n",
+       "      <td>0.531957</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>64</td>\n",
+       "      <td>0.695300</td>\n",
+       "      <td>0.511987</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>80</td>\n",
+       "      <td>0.626100</td>\n",
+       "      <td>0.504223</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>96</td>\n",
+       "      <td>0.624600</td>\n",
+       "      <td>0.500726</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>112</td>\n",
+       "      <td>0.596800</td>\n",
+       "      <td>0.499541</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>128</td>\n",
+       "      <td>0.596800</td>\n",
+       "      <td>0.497704</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>144</td>\n",
+       "      <td>0.594900</td>\n",
+       "      <td>0.498355</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table><p>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "TrainOutput(global_step=157, training_loss=0.8314582435948075, metrics={'train_runtime': 99.5279, 'train_samples_per_second': 50.237, 'train_steps_per_second': 1.577, 'total_flos': 1.860231576674304e+16, 'train_loss': 0.8314582435948075, 'epoch': 1.0})"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from transformers import TrainingArguments, Trainer\n",
+    "\n",
+    "for n, p in model.named_parameters():\n",
+    "    # training only the projector/lora layers\n",
+    "    p.requires_grad = \"projector\" in n or \"lora\" in n\n",
+    "\n",
+    "args = TrainingArguments(\n",
+    "    output_dir=\"save_dir\",\n",
+    "    remove_unused_columns=False,\n",
+    "    report_to=\"none\",\n",
+    "    bf16=True,\n",
+    "    eval_strategy=\"steps\",\n",
+    "    save_strategy=\"no\",\n",
+    "    eval_steps=0.1,\n",
+    "    dataloader_num_workers=16,\n",
+    "    per_device_train_batch_size=16, \n",
+    "    per_device_eval_batch_size=16, \n",
+    "    gradient_accumulation_steps=2,\n",
+    "    num_train_epochs=1.0,\n",
+    "    warmup_ratio=0.2,\n",
+    "    logging_steps=0.1,\n",
+    "    learning_rate=3e-5,\n",
+    "    data_seed=42,\n",
+    ")\n",
+    "data_collator = GraniteCollator(processor)\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    args=args,\n",
+    "    train_dataset=train_dataset,\n",
+    "    eval_dataset=val_dataset,\n",
+    "    data_collator=data_collator,\n",
+    "    processing_class=processor,\n",
+    ")\n",
+    "trainer.train()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d6efebd1",
+   "metadata": {},
+   "source": [
+    "## Checking for improvements\n",
+    "Looks like both the training and validation loss are dropping. \n",
+    "Let's check if the test WER improved by our very lightweight finetuning"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "1bf89d54",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Running inference: 100%|████████████████████████████████████████████████████████████| 10/10 [00:25<00:00,  2.56s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "WER after finetuning 9.552\n",
+      "WER improvement 0.167\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "wer_after_train = compute_wer(model, processor, test_dataset)\n",
+    "\n",
+    "print(f\"WER after finetuning {wer_after_train*100:.3f}\")\n",
+    "print(f\"WER improvement {(wer_before_train - wer_after_train)*100:.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ff6fdd3e",
+   "metadata": {},
+   "source": [
+    "## Summary\n",
+    "Hurray! We've managed to slightly improve the WER with quick, lightweight finetuning. \n",
+    "In this notebook you learned how to:\n",
+    "- Prepare training data for Granite Speech\n",
+    "- Run batched inference with Granite Speech, and compute Word Error Rate\n",
+    "- Finetune GraniteSpeech, applying granient update only to the adapter/projector layers\n",
+    "\n",
+    "I'd like to thank the following for their help:\n",
+    "Avishai Elmakies, George Saon, Alexander Brooks, Eliyahu Schwartz"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/en/index.md b/notebooks/en/index.md
index 2effb785..58029451 100644
--- a/notebooks/en/index.md
+++ b/notebooks/en/index.md
@@ -6,7 +6,7 @@ applications and solving various machine learning tasks using open-source tools
 
 ## Latest notebooks
 Check out the recently added notebooks:
-
+- [Fine-tuning Granite Speech](fine_tuning_granite_speech)
 - [Fine-tuning T5 for Automatic GitHub Tag Generation with PEFT](finetune_t5_for_search_tag_generation)
 - [Documentation Chatbot with Meta Synthetic Data Kit](fine_tune_chatbot_docs_synthetic)
 - [HuatuoGPT-o1 Medical RAG and Reasoning](medical_rag_and_Reasoning)

From 66a29632e8dce7ea701d32ca3047e1ef4f02358a Mon Sep 17 00:00:00 2001
From: Avihu Dekel
Date: Sun, 29 Jun 2025 15:58:05 +0000
Subject: [PATCH 2/5] update with used packages

---
 notebooks/en/fine_tuning_granite_speech.ipynb | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/notebooks/en/fine_tuning_granite_speech.ipynb b/notebooks/en/fine_tuning_granite_speech.ipynb
index 8a0bddd7..97bce54c 100644
--- a/notebooks/en/fine_tuning_granite_speech.ipynb
+++ b/notebooks/en/fine_tuning_granite_speech.ipynb
@@ -10,7 +10,7 @@
     "# Finetuning Granite Speech\n",
     "\n",
     "[Granite speech](https://huggingface.co/collections/ibm-granite/granite-speech-67e45da088d5092ff6b901c7) is a family of powerful speech models that excel in speech recognition and speech translation. \n",
-    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (in June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it is easy to finetune on unseen data or add new tasks.\n",
+    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (in June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it is easier to finetune on unseen data or add new tasks.\n",
     "\n",
     "In this example, we'll show how to:\n",
     "1. Run inference with Granite Speech\n",
@@ -29,12 +29,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "id": "d985dae1",
     "metadata": {},
     "outputs": [],
     "source": [
-    "# todo!"
+    "# install packages\n",
+    "!pip install -q git+https://github.com/huggingface/transformers.git\n",
+    "!pip install -U -q datasets peft accelerate evaluate jiwer torchaudio openai-whisper tqdm\n"
     ]
    },
    {

From d11d148482cc4b10d8ea01c5819218e7504a5910 Mon Sep 17 00:00:00 2001
From: Avihu Dekel
Date: Sun, 29 Jun 2025 16:00:43 +0000
Subject: [PATCH 3/5] minor

---
 notebooks/en/fine_tuning_granite_speech.ipynb | 1 +
 1 file changed, 1 insertion(+)

diff --git a/notebooks/en/fine_tuning_granite_speech.ipynb b/notebooks/en/fine_tuning_granite_speech.ipynb
index 97bce54c..4b59b898 100644
--- a/notebooks/en/fine_tuning_granite_speech.ipynb
+++ b/notebooks/en/fine_tuning_granite_speech.ipynb
@@ -16,6 +16,7 @@
     "1. Run inference with Granite Speech\n",
     "2. Evaluate the predictions\n",
     "3. Finetune the model with new data.\n",
+    "\n",
     "Specifically, we'll finetune Granite Speech 2B on [GigaSpeech](https://huggingface.co/datasets/speechcolab/gigaspeech), a large spontaneous conversational dataset which was not included in the model's training. \n"
    ]
   },

From 19d89a05831eba702e3dc7d7bb4b3b996cb4fafa Mon Sep 17 00:00:00 2001
From: Avihu Dekel
Date: Sun, 29 Jun 2025 16:25:36 +0000
Subject: [PATCH 4/5] typos

---
 notebooks/en/fine_tuning_granite_speech.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/notebooks/en/fine_tuning_granite_speech.ipynb b/notebooks/en/fine_tuning_granite_speech.ipynb
index 4b59b898..bb2941a8 100644
--- a/notebooks/en/fine_tuning_granite_speech.ipynb
+++ b/notebooks/en/fine_tuning_granite_speech.ipynb
@@ -10,7 +10,7 @@
     "# Finetuning Granite Speech\n",
     "\n",
     "[Granite speech](https://huggingface.co/collections/ibm-granite/granite-speech-67e45da088d5092ff6b901c7) is a family of powerful speech models that excel in speech recognition and speech translation. \n",
-    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (in June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it is easier to finetune on unseen data or add new tasks.\n",
+    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (as of June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it is easier to finetune on unseen data or add new tasks.\n",
     "\n",
     "In this example, we'll show how to:\n",
     "1. Run inference with Granite Speech\n",
@@ -532,7 +532,7 @@
     "In this notebook you learned how to:\n",
     "- Prepare training data for Granite Speech\n",
     "- Run batched inference with Granite Speech, and compute Word Error Rate\n",
-    "- Finetune GraniteSpeech, applying granient update only to the adapter/projector layers\n",
+    "- Finetune GraniteSpeech, applying gradient updates only to the adapter/projector layers\n",
     "\n",
     "I'd like to thank the following for their help:\n",
     "Avishai Elmakies, George Saon, Alexander Brooks, Eliyahu Schwartz"

From 185e05fc5b7b17d7477af0d865676c292c1b67d9 Mon Sep 17 00:00:00 2001
From: Avihu Dekel
Date: Mon, 30 Jun 2025 05:49:07 +0000
Subject: [PATCH 5/5] minor

---
 notebooks/en/fine_tuning_granite_speech.ipynb | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/notebooks/en/fine_tuning_granite_speech.ipynb b/notebooks/en/fine_tuning_granite_speech.ipynb
index bb2941a8..2447c925 100644
--- a/notebooks/en/fine_tuning_granite_speech.ipynb
+++ b/notebooks/en/fine_tuning_granite_speech.ipynb
@@ -10,7 +10,7 @@
     "# Finetuning Granite Speech\n",
     "\n",
     "[Granite speech](https://huggingface.co/collections/ibm-granite/granite-speech-67e45da088d5092ff6b901c7) is a family of powerful speech models that excel in speech recognition and speech translation. \n",
-    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (as of June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it is easier to finetune on unseen data or add new tasks.\n",
+    "While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (as of June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is more lightweight, which makes it easier to finetune on unseen data or add new tasks.\n",
     "\n",
     "In this example, we'll show how to:\n",
     "1. Run inference with Granite Speech\n",
@@ -125,7 +125,7 @@
     "Let's continue with data processing:\n",
     "- The text format requires some preprocessing (e.g. replace `<COMMA>` with `,`).\n",
     "- Add an instruction prompt (e.g. `Can you transcribe the following speech<|audio|>?`)\n",
-    "- Filter non-verbal examples"
+    "- Filter non-verbal examples (e.g. `<MUSIC>`)"
     ]
    },
    {
@@ -483,7 +483,7 @@
     "source": [
     "## Checking for improvements\n",
     "Looks like both the training and validation loss are dropping. \n",
-    "Let's check if the test WER improved by our very lightweight finetuning"
+    "Let's check if the test WER improved by our very lightweight finetuning."
     ]
    },
    {
@@ -535,7 +535,7 @@
     "- Finetune GraniteSpeech, applying gradient updates only to the adapter/projector layers\n",
     "\n",
     "I'd like to thank the following for their help:\n",
-    "Avishai Elmakies, George Saon, Alexander Brooks, Eliyahu Schwartz"
+    "Avishai Elmakies, George Saon, Alexander Brooks and Eliyahu Schwartz."
     ]
    },