Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ build

crepe/model-*.h5
crepe/model-*.h5.bz2

model*.onnx
29 changes: 29 additions & 0 deletions crepe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,32 @@ def process_file(file, output=None, model_capacity='full', viterbi=False,
if verbose:
print("CREPE: Saved the salience plot at {}".format(plot_file))


def export_model_to_onnx(model_capacity='full', output_path='model.onnx', opset=13):
    """
    Export a CREPE model to ONNX format.

    Parameters
    ----------
    model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
        String specifying the model capacity to export
    output_path : str
        Path where the ONNX model will be saved
    opset : int
        ONNX opset version to use for conversion

    Raises
    ------
    ValueError
        If `model_capacity` is not one of the supported capacities.

    Returns
    -------
    None
    """
    # Imported lazily so that importing crepe itself does not require
    # tf2onnx — it is only needed for the export path.
    import tensorflow as tf
    import tf2onnx

    # Fail fast with a clear message before loading any model weights.
    valid_capacities = ('tiny', 'small', 'medium', 'large', 'full')
    if model_capacity not in valid_capacities:
        raise ValueError(
            "model_capacity must be one of {}, got {!r}".format(
                valid_capacities, model_capacity))

    model = build_and_load_model(model_capacity)

    # CREPE consumes fixed 1024-sample audio frames; the batch dimension
    # is left dynamic (None) so any batch size works at inference time.
    spec = (tf.TensorSpec((None, 1024), tf.float32, name="input"),)
    onnx_model, _ = tf2onnx.convert.from_keras(
        model, input_signature=spec, opset=opset)

    with open(output_path, "wb") as f:
        f.write(onnx_model.SerializeToString())

    print(f"CREPE: Exported {model_capacity} model to {output_path}")
37 changes: 37 additions & 0 deletions crepe/onnx_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
"""
Script to export CREPE models to ONNX format
"""
import argparse
import os
from crepe.core import export_model_to_onnx

def main():
    """Parse command-line arguments and export the requested CREPE model.

    Usage: onnx_export.py {tiny,small,medium,large,full} [-o OUTPUT] [--opset N]
    """
    parser = argparse.ArgumentParser(
        description='Export CREPE model to ONNX format')
    parser.add_argument('capacity',
                        choices=['tiny', 'small', 'medium', 'large', 'full'],
                        help='Model capacity to export')
    parser.add_argument('-o', '--output', default=None,
                        help='Output path for the ONNX model '
                             '(default: model-{capacity}.onnx)')
    parser.add_argument('--opset', type=int, default=13,
                        help='ONNX opset version (default: 13)')

    args = parser.parse_args()

    # Derive a capacity-specific default filename when none was given.
    output_path = args.output
    if output_path is None:
        output_path = f"model-{args.capacity}.onnx"

    print(f"Exporting {args.capacity} model to {output_path}...")
    export_model_to_onnx(
        model_capacity=args.capacity,
        output_path=output_path,
        opset=args.opset,
    )

    print("Export complete!")

if __name__ == "__main__":
    # Entry point; all real work (argument parsing, model loading,
    # conversion) happens inside main().
    main()