diff --git a/.gitignore b/.gitignore index 9fbe219..b8c76cf 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ build crepe/model-*.h5 crepe/model-*.h5.bz2 + +model.onnx diff --git a/crepe/core.py b/crepe/core.py index 03ea52e..77c722f 100644 --- a/crepe/core.py +++ b/crepe/core.py @@ -364,3 +364,32 @@ def process_file(file, output=None, model_capacity='full', viterbi=False, if verbose: print("CREPE: Saved the salience plot at {}".format(plot_file)) + +def export_model_to_onnx(model_capacity='full', output_path='model.onnx', opset=13): + """ + Export a CREPE model to ONNX format + + Parameters + ---------- + model_capacity : 'tiny', 'small', 'medium', 'large', or 'full' + String specifying the model capacity to export + output_path : str + Path where the ONNX model will be saved + opset : int + ONNX opset version to use for conversion + + Returns + ------- + None + """ + import tensorflow as tf + import tf2onnx + + model = build_and_load_model(model_capacity) + spec = (tf.TensorSpec((None, 1024), tf.float32, name="input"),) + onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=opset) + + with open(output_path, "wb") as f: + f.write(onnx_model.SerializeToString()) + + print(f"CREPE: Exported {model_capacity} model to {output_path}") \ No newline at end of file diff --git a/crepe/onnx_export.py b/crepe/onnx_export.py new file mode 100644 index 0000000..d038f53 --- /dev/null +++ b/crepe/onnx_export.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +""" +Script to export CREPE models to ONNX format +""" +import argparse +import os +from crepe.core import export_model_to_onnx + +def main(): + print("Starting script...") + parser = argparse.ArgumentParser(description='Export CREPE model to ONNX format') + parser.add_argument('capacity', choices=['tiny', 'small', 'medium', 'large', 'full'], + help='Model capacity to export') + parser.add_argument('-o', '--output', default=None, + help='Output path for the ONNX model (default: model-{capacity}.onnx)') + parser.add_argument('--opset', type=int, default=13, + help='ONNX opset version (default: 13)') + + args = parser.parse_args() + + output_path = args.output + if output_path is None: + output_path = f"model-{args.capacity}.onnx" + + print(f"Exporting {args.capacity} model to {output_path}...") + export_model_to_onnx( + model_capacity=args.capacity, + output_path=output_path, + opset=args.opset + ) + + print("Export complete!") + +if __name__ == "__main__": + import tensorflow as tf + print(f"TensorFlow version: {tf.__version__}") + main() \ No newline at end of file