
Whisper Redesigned Solution #1229


Open: wants to merge 81 commits into base: main

Commits (81)
c2c8745
Rename Whisper encoder input to audio features
kunal-vaishnavi Sep 30, 2024
1d5f4f0
Initial commit for new export
kunal-vaishnavi Oct 22, 2024
5bf4628
Fix KV cache initialization and runtime bugs
kunal-vaishnavi Nov 2, 2024
3cb936e
Add another check for alignment heads input
kunal-vaishnavi Nov 5, 2024
b648f58
Dump logits in ORT GenAI
kunal-vaishnavi Nov 7, 2024
2a5b762
Fix cross QK update
kunal-vaishnavi Nov 14, 2024
e24db74
Fix finalize cross QK
kunal-vaishnavi Nov 15, 2024
e4c838e
Save checkpoint for working solution
kunal-vaishnavi Nov 15, 2024
3a548a1
Clean up code
kunal-vaishnavi Nov 17, 2024
4d9af67
Remove unneeded template instantiations
kunal-vaishnavi Nov 21, 2024
1d9161d
Fixes: update crossQK copy for first step;
mindest Nov 27, 2024
97be76a
Enable getting model inputs to user
kunal-vaishnavi Dec 4, 2024
1bcd264
Add additional check for cache indirection
kunal-vaishnavi Dec 6, 2024
c35a73d
Add audio processing unit test
kunal-vaishnavi Dec 18, 2024
1d5da61
Fix Whisper GenAI config
kunal-vaishnavi Dec 18, 2024
efd0199
Save checkpoint for working solution
kunal-vaishnavi Dec 21, 2024
fbebe68
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Dec 21, 2024
ef955e7
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Feb 5, 2025
e869d02
Squashed commit of the following:
kunal-vaishnavi Feb 6, 2025
32c48d2
Initial changes to work with main
kunal-vaishnavi Feb 17, 2025
e4a8b5f
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Feb 17, 2025
323028a
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Mar 24, 2025
7756a86
Resolving build errors after merging main
kunal-vaishnavi Mar 25, 2025
a167add
Fix prompt length and get input
kunal-vaishnavi Mar 28, 2025
8782b47
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Mar 28, 2025
c93a1ab
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Apr 25, 2025
c0efa93
Fix build issues after syncing with main
kunal-vaishnavi Apr 26, 2025
27ba626
Add gpt2 to list of LLMs
kunal-vaishnavi Apr 26, 2025
2eb198b
Cast from ORT float16 to uint16 and then uint16 to half
kunal-vaishnavi Apr 26, 2025
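Commit 2eb198b sidesteps a direct conversion between ORT's float16 type and CUDA's half by round-tripping values through their shared 16-bit pattern. The same idea can be sketched in pure Python with the standard `struct` module (`"e"` is IEEE 754 half-precision); the function names here are illustrative, not part of the PR, and the actual change is C++:

```python
import struct

def f16_to_bits(value: float) -> int:
    # Pack as IEEE 754 half-precision ("e"), then reinterpret the
    # same two bytes as an unsigned 16-bit integer ("H").
    return struct.unpack("<H", struct.pack("<e", value))[0]

def bits_to_f16(bits: int) -> float:
    # Inverse: treat the 16-bit pattern as a half-precision float.
    return struct.unpack("<e", struct.pack("<H", bits))[0]

print(hex(f16_to_bits(1.5)))  # 0x3e00
print(bits_to_f16(0x3E00))    # 1.5
```

Because both directions only reinterpret bytes, the round trip is lossless for every representable half value, which is what makes the uint16 hop safe as an intermediate type.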
d4e7446
Remove const casting
kunal-vaishnavi Apr 26, 2025
8558cab
Fix windows build errors
kunal-vaishnavi Apr 26, 2025
40a555a
Update processing for audio features
kunal-vaishnavi Apr 28, 2025
4c3752c
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi May 7, 2025
5fa7300
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Jun 10, 2025
198db2d
Fix duplicate config names after merging main
kunal-vaishnavi Jun 10, 2025
3264244
Add comments to C API process methods
kunal-vaishnavi Jun 13, 2025
83a915d
Move SetInputs from params to generator
kunal-vaishnavi Jun 13, 2025
c876bd2
Use SetExtraInputs for all states
kunal-vaishnavi Jun 17, 2025
095d452
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Jun 17, 2025
67359da
Fix build errors after merging main
kunal-vaishnavi Jun 17, 2025
59ae78d
Align spacing for comment
kunal-vaishnavi Jun 17, 2025
c12d16b
Add extra inputs back to decoder only state
kunal-vaishnavi Jun 17, 2025
eec4bba
Always call SetExtraInputs
kunal-vaishnavi Jun 18, 2025
da6fc9a
Add audio processing APIs in other languages
kunal-vaishnavi Jun 18, 2025
1e29004
Comment out multi-prompt APIs for now
kunal-vaishnavi Jun 18, 2025
a0d5be7
Fix Java build issues with new audio classes
kunal-vaishnavi Jun 18, 2025
c474d62
Add missing Objective-C interfaces for new audio classes
kunal-vaishnavi Jun 18, 2025
eb88379
Fix variable names in setting inputs for Java API
kunal-vaishnavi Jun 18, 2025
502afb3
Update Java unit tests
kunal-vaishnavi Jun 18, 2025
c7a865b
Fix tensor unit test in Java
kunal-vaishnavi Jun 18, 2025
6edf8d3
Add C/C++ APIs to set batched input ids
kunal-vaishnavi Jun 19, 2025
aed1720
Start updating Whisper inference examples
kunal-vaishnavi Jun 19, 2025
eb285b9
Update Whisper examples and add Python pre-processing binding
kunal-vaishnavi Jun 20, 2025
7edc027
Update audio preprocessing unit tests
kunal-vaishnavi Jun 20, 2025
85689af
Add changes suggested by clang-format and CodeQL
kunal-vaishnavi Jun 20, 2025
29cf80b
Remove extra newline for clang-format
kunal-vaishnavi Jun 20, 2025
d1a7608
Add Python CI test for Whisper
kunal-vaishnavi Jun 20, 2025
fe27ea1
Fix cache indirection updating
kunal-vaishnavi Jun 21, 2025
8883492
Fix build warning in Windows CIs
kunal-vaishnavi Jun 21, 2025
1839220
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Jun 21, 2025
92b32d9
Remove commented out code
kunal-vaishnavi Jun 24, 2025
9bcc681
Use feature extraction instead of speech log mel
kunal-vaishnavi Jun 24, 2025
1a1ef86
Fix variable names based on PR feedback
kunal-vaishnavi Jun 24, 2025
9577e36
Fix import name for E2E unit tests
kunal-vaishnavi Jun 24, 2025
96251cd
Update ORT extensions commit
kunal-vaishnavi Jun 25, 2025
8e40be7
Only transpose K caches when DMMHA is used
kunal-vaishnavi Jun 25, 2025
17cc672
Fix extra inputs usage for pipeline and GPT models
kunal-vaishnavi Jun 26, 2025
8fb2f1b
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Jun 26, 2025
0aa011b
Move SetExtraInputs to the right state
kunal-vaishnavi Jun 26, 2025
2681f35
Access sessions field through model object
kunal-vaishnavi Jun 26, 2025
7f283a4
Rewrite batched preprocessing APIs
kunal-vaishnavi Jun 27, 2025
a7cbacc
Use different C++ API call for one prompt in preprocessing
kunal-vaishnavi Jun 27, 2025
9989264
Remove vector usage for C-only environment in Java bindings
kunal-vaishnavi Jun 27, 2025
4ca81bd
Cast pybind str to std string
kunal-vaishnavi Jun 27, 2025
04513a2
Remove OgaCheckResult from Java bindings
kunal-vaishnavi Jun 27, 2025
4a76e15
Fix typo in Java doc string
kunal-vaishnavi Jun 27, 2025
b316861
Fix NativeMethods function name
kunal-vaishnavi Jun 27, 2025
5141f00
Add changes suggested by linters
kunal-vaishnavi Jun 27, 2025
2e3560a
Change how strdup is defined
kunal-vaishnavi Jun 27, 2025
d977340
Add changes from PR feedback
kunal-vaishnavi Jun 28, 2025
363ca7c
Activate Whisper E2E CI tests
kunal-vaishnavi Jun 28, 2025
4 changes: 2 additions & 2 deletions .github/workflows/linux-gpu-x64-build.yml
@@ -143,7 +143,7 @@ jobs:
docker run \
--gpus all \
--rm \
--volume /data/ortgenai/pytorch:/data/ortgenai/pytorch \
--volume /data/ortgenai/:/data/ortgenai/ \
--volume $GITHUB_WORKSPACE:/ort_genai_src \
-e HF_TOKEN=$HF_TOKEN \
-w /ort_genai_src onnxruntimecudabuildx64 bash -c " \
@@ -170,6 +170,6 @@ jobs:
docker run \
--gpus all \
--rm \
--volume /data/ortgenai/pytorch:/data/ortgenai/pytorch \
--volume /data/ortgenai/:/data/ortgenai/ \
--volume $GITHUB_WORKSPACE:/ort_genai_src \
-w /ort_genai_src onnxruntimecudabuildx64 bash -c "ORTGENAI_LOG_ORT_LIB=1 LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/ort_genai_src/build/cuda/ /ort_genai_src/build/cuda/unit_tests"
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;5ea4b9b0683b83c1d6800eb332f37dcc76bb2e61
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;a85fa861ee5e5300f16142bd969ede0eabc61c86

# These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
llguidance;https://github.com/microsoft/llguidance.git;2d2f1de3c87e3289528affc346f734f7471216d9
2 changes: 1 addition & 1 deletion examples/c/src/phi3v.cpp
@@ -63,9 +63,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
std::cout << "Generating response..." << std::endl;
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 7680);
params->SetInputs(*input_tensors);

auto generator = OgaGenerator::Create(*model, *params);
generator->SetInputs(*input_tensors);

while (!generator->IsDone()) {
generator->GenerateNextToken();
2 changes: 1 addition & 1 deletion examples/c/src/phi4-mm.cpp
@@ -94,9 +94,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
std::cout << "Generating response..." << std::endl;
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 7680);
params->SetInputs(*input_tensors);

auto generator = OgaGenerator::Create(*model, *params);
generator->SetInputs(*input_tensors);

while (!generator->IsDone()) {
generator->GenerateNextToken();
58 changes: 23 additions & 35 deletions examples/c/src/whisper.cpp
@@ -15,8 +15,6 @@ void CXX_API(const char* model_path, int32_t num_beams) {
auto model = OgaModel::Create(model_path);
std::cout << "Creating multimodal processor..." << std::endl;
auto processor = OgaMultiModalProcessor::Create(*model);
std::cout << "Creating tokenizer..." << std::endl;
auto tokenizer = OgaTokenizer::Create(*model);

while (true) {
std::string audio_paths_str;
@@ -42,31 +40,24 @@
audios = OgaAudios::Load(audio_paths_c);
}

std::cout << "Processing audio..." << std::endl;
auto mel = processor->ProcessAudios(audios.get());
const std::vector<const char*> prompt_tokens = {"<|startoftranscript|>", "<|en|>", "<|transcribe|>",
"<|notimestamps|>"};
auto input_ids = OgaSequences::Create();
std::cout << "Processing inputs..." << std::endl;
const size_t batch_size = audio_paths.size();
for (size_t i = 0; i < batch_size; ++i) {
for (const auto& token : prompt_tokens) {
input_ids->Append(tokenizer->ToTokenId(token), i);
}
}
const char* prompt_tokens = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>";
const std::vector<const char*> prompts(batch_size, prompt_tokens);
auto inputs = processor->ProcessAudios(prompts, audios.get());

std::cout << "Generating response..." << std::endl;
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 256);
params->SetSearchOption("batch_size", static_cast<double>(batch_size));
params->SetSearchOption("max_length", 448);
params->SetSearchOptionBool("do_sample", false);
params->SetSearchOption("num_beams", num_beams);
params->SetSearchOption("num_return_sequences", num_beams);
params->SetInputs(*mel);
params->SetInputSequences(*input_ids);

auto generator = OgaGenerator::Create(*model, *params);
generator->SetInputs(*inputs);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();
}

@@ -133,36 +124,29 @@ void C_API(const char* model_path, int32_t num_beams) {
}

std::cout << "Processing audio..." << std::endl;
OgaNamedTensors* mel;
CheckResult(OgaProcessorProcessAudios(processor, audios, &mel));
const std::vector<const char*> prompt_tokens = {"<|startoftranscript|>", "<|en|>", "<|transcribe|>",
"<|notimestamps|>"};
OgaSequences* input_ids;
CheckResult(OgaCreateSequences(&input_ids));
OgaNamedTensors* inputs;
const size_t batch_size = audio_paths.size();
for (size_t i = 0; i < batch_size; ++i) {
for (const auto& token : prompt_tokens) {
int32_t token_id;
CheckResult(OgaTokenizerToTokenId(tokenizer, token, &token_id));
CheckResult(OgaAppendTokenToSequence(token_id, input_ids, i));
}
}
const char* prompt_tokens = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>";
std::vector<const char*> prompts(batch_size, prompt_tokens);
OgaStringArray* prompts_string_array;
CheckResult(OgaCreateStringArrayFromStrings(prompts.data(), prompts.size(), &prompts_string_array));
CheckResult(OgaProcessorProcessAudiosAndPrompts(processor, prompts_string_array, audios, &inputs));
OgaDestroyStringArray(prompts_string_array);

std::cout << "Generating response..." << std::endl;
OgaGeneratorParams* params;
CheckResult(OgaCreateGeneratorParams(model, &params));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 256));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "batch_size", static_cast<double>(batch_size)));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 448));
CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", false));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "num_beams", num_beams));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "num_return_sequences", num_beams));
CheckResult(OgaGeneratorParamsSetInputs(params, mel));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_ids));

OgaGenerator* generator;
CheckResult(OgaCreateGenerator(model, params, &generator));
CheckResult(OgaGenerator_SetInputs(generator, inputs));

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken(generator));
}

@@ -182,8 +166,7 @@

OgaDestroyGenerator(generator);
OgaDestroyGeneratorParams(params);
OgaDestroySequences(input_ids);
OgaDestroyNamedTensors(mel);
OgaDestroyNamedTensors(inputs);
OgaDestroyAudios(audios);
}

@@ -203,6 +186,11 @@ int main(int argc, char** argv) {
return -1;
}

// Uncomment for debugging purposes
// Oga::SetLogBool("enabled", true);
// Oga::SetLogBool("model_input_values", true);
// Oga::SetLogBool("model_output_values", true);

std::cout << "---------------" << std::endl;
std::cout << "Hello, Whisper!" << std::endl;
std::cout << "---------------" << std::endl;
2 changes: 1 addition & 1 deletion examples/csharp/HelloPhi3V/Program.cs
@@ -163,9 +163,9 @@ void PrintUsage()
Console.WriteLine("Generating response...");
using GeneratorParams generatorParams = new GeneratorParams(model);
generatorParams.SetSearchOption("max_length", 7680);
generatorParams.SetInputs(inputTensors);

using var generator = new Generator(model, generatorParams);
generator.SetInputs(inputTensors);
var watch = System.Diagnostics.Stopwatch.StartNew();
while (!generator.IsDone())
{
2 changes: 1 addition & 1 deletion examples/csharp/HelloPhi4MM/Program.cs
@@ -204,9 +204,9 @@ void PrintUsage()
Console.WriteLine("Generating response...");
using GeneratorParams generatorParams = new GeneratorParams(model);
generatorParams.SetSearchOption("max_length", 7680);
generatorParams.SetInputs(inputTensors);

using var generator = new Generator(model, generatorParams);
generator.SetInputs(inputTensors);
var watch = System.Diagnostics.Stopwatch.StartNew();
while (!generator.IsDone())
{
3 changes: 2 additions & 1 deletion examples/python/model-vision.py
@@ -8,6 +8,7 @@
from pathlib import Path

import onnxruntime_genai as og
# og.set_log_options(enabled=True, model_input_values=True, model_output_values=True)

def _find_dir_contains_sub_dir(current_dir: Path, target_dir_name):
curr_path = Path(current_dir).absolute()
@@ -103,10 +104,10 @@ def run(args: argparse.Namespace):

print("Generating response...")
params = og.GeneratorParams(model)
params.set_inputs(inputs)
params.set_search_options(max_length=7680)

generator = og.Generator(model, params)
generator.set_inputs(inputs)
start_time = time.time()

while not generator.is_done():
2 changes: 1 addition & 1 deletion examples/python/phi4-mm.py
@@ -124,10 +124,10 @@ def run(args: argparse.Namespace):

print("Generating response...")
params = og.GeneratorParams(model)
params.set_inputs(inputs)
params.set_search_options(max_length=7680)

generator = og.Generator(model, params)
generator.set_inputs(inputs)
start_time = time.time()

while not generator.is_done():
56 changes: 46 additions & 10 deletions examples/python/whisper.py
@@ -7,7 +7,7 @@
import readline

import onnxruntime_genai as og

# og.set_log_options(enabled=True, model_input_values=True, model_output_values=True)

def _complete(text, state):
return (glob.glob(text + "*") + [None])[state]
@@ -20,15 +20,25 @@ class Format:

def run(args: argparse.Namespace):
print("Loading model...")
model = og.Model(args.model_path)
config = og.Config(args.model_path)
if args.execution_provider != "follow_config":
config.clear_providers()
if args.execution_provider != "cpu":
print(f"Setting model to {args.execution_provider}")
config.append_provider(args.execution_provider)
model = og.Model(config)
processor = model.create_multimodal_processor()
tokenizer = og.Tokenizer(model)

while True:
readline.set_completer_delims(" \t\n;")
readline.parse_and_bind("tab: complete")
readline.set_completer(_complete)
audio_paths = [audio_path.strip() for audio_path in input("Audio Paths (comma separated): ").split(",")]

if args.non_interactive:
audio_paths = [args.audio]
else:
audio_paths = [audio_path.strip() for audio_path in input("Audio Paths (comma separated): ").split(",")]
if len(audio_paths) == 0:
raise ValueError("No audio provided.")

@@ -39,28 +49,27 @@ def run(args: argparse.Namespace):
audios = og.Audios.open(*audio_paths)

print("Processing audio...")
mel = processor(audios=audios)
batch_size = len(audio_paths)
decoder_prompt_tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
prompts = ["".join(decoder_prompt_tokens)] * batch_size
inputs = processor(prompts, audios=audios)

params = og.GeneratorParams(model)
params.set_search_options(
do_sample=False,
num_beams=args.num_beams,
num_return_sequences=args.num_beams,
max_length=256,
max_length=448,
)

batch_size = len(audio_paths)
params.set_inputs(mel)
params.input_ids = [[tokenizer.to_token_id(token) for token in decoder_prompt_tokens]] * batch_size

generator = og.Generator(model, params)
generator.set_inputs(inputs)

while not generator.is_done():
generator.compute_logits()
generator.generate_next_token()

print()
transcriptions = []
for i in range(batch_size * args.num_beams):
tokens = generator.get_sequence(i)
transcription = processor.decode(tokens)
@@ -69,18 +78,45 @@
print(
f" {Format.underline}batch {i // args.num_beams}, beam {i % args.num_beams}{Format.end}: {transcription}"
)
transcriptions.append(transcription.strip())

for _ in range(3):
print()

if args.non_interactive:
args.output = args.output.strip()
matching = False
for transcription in transcriptions:
if transcription == args.output:
matching = True
break

if matching:
print("One of the model's transcriptions matches the expected transcription.")
return
raise Exception("None of the model's transcriptions match the expected transcription.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model_path", type=str, required=True, help="Path to the model"
)
parser.add_argument(
'-e', '--execution_provider', type=str, required=False, default='follow_config', choices=["cpu", "cuda", "follow_config"],
help="Execution provider to run the ONNX Runtime session with. Defaults to follow_config that uses the execution provider listed in the genai_config.json instead."
)
parser.add_argument(
"-b", "--num_beams", type=int, default=4, help="Number of beams"
)
parser.add_argument(
"-a", "--audio", type=str, default="", help="Path to audio file for CI testing purposes"
)
parser.add_argument(
"-o", "--output", type=str, default="", help="Expected transcribed output for CI testing purposes"
)
parser.add_argument(
"-ni", "--non_interactive", default=False, action="store_true", help="Non-interactive mode for CI testing purposes"
)
args = parser.parse_args()
run(args)
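The redesigned whisper.py above adds two bits of pure-Python logic: batched prompt construction (one concatenated decoder prompt per audio, passed to `processor(prompts, audios=audios)`) and a non-interactive CI check that passes if any beam's transcription equals the expected output. Isolated as a sketch (the helper names are illustrative; the example inlines this logic):

```python
DECODER_PROMPT_TOKENS = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]

def build_prompts(batch_size: int) -> list:
    # One concatenated decoder prompt per audio in the batch, matching
    # what the new batched preprocessing API expects.
    return ["".join(DECODER_PROMPT_TOKENS)] * batch_size

def matches_expected(transcriptions, expected: str) -> bool:
    # CI check from the --non_interactive path: pass if any returned
    # transcription equals the expected output after stripping whitespace.
    expected = expected.strip()
    return any(t.strip() == expected for t in transcriptions)

print(build_prompts(1)[0])  # <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
print(matches_expected([" hello world "], "hello world"))  # True
```

Comparing against all `batch_size * num_beams` transcriptions, rather than just the top beam, keeps the CI test stable when beams differ only in ranking.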