diff --git a/Cargo.lock b/Cargo.lock
index 3c2324ac..8661b1cd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -419,6 +419,7 @@ dependencies = [
"graphviz-rust",
"half",
"hf-hub",
+ "hound",
"image",
"log",
"memmap2",
@@ -427,6 +428,7 @@ dependencies = [
"open-hypergraphs",
"open-hypergraphs-dot",
"rayon",
+ "rustfft",
"safetensors 0.7.0",
"serde",
"serde_json",
@@ -1483,6 +1485,12 @@ dependencies = [
"windows-sys 0.60.2",
]
+[[package]]
+name = "hound"
+version = "3.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
+
[[package]]
name = "http"
version = "1.4.0"
@@ -2435,6 +2443,15 @@ dependencies = [
"zerocopy",
]
+[[package]]
+name = "primal-check"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08"
+dependencies = [
+ "num-integer",
+]
+
[[package]]
name = "proc-macro2"
version = "1.0.103"
@@ -2740,6 +2757,20 @@ dependencies = [
"windows-sys 0.52.0",
]
+[[package]]
+name = "rustfft"
+version = "6.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89"
+dependencies = [
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "primal-check",
+ "strength_reduce",
+ "transpose",
+]
+
[[package]]
name = "rustix"
version = "1.1.2"
@@ -3004,6 +3035,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+[[package]]
+name = "strength_reduce"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
+
[[package]]
name = "strsim"
version = "0.11.1"
@@ -3352,6 +3389,16 @@ dependencies = [
"tracing-log",
]
+[[package]]
+name = "transpose"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e"
+dependencies = [
+ "num-integer",
+ "strength_reduce",
+]
+
[[package]]
name = "typed-builder"
version = "0.23.2"
diff --git a/catgrad-llm/Cargo.toml b/catgrad-llm/Cargo.toml
index 3143227e..7c223487 100644
--- a/catgrad-llm/Cargo.toml
+++ b/catgrad-llm/Cargo.toml
@@ -32,6 +32,8 @@ serde_with = { version = "3.17", default-features = false, features = ["macros"]
serde_path_to_error = "0.1"
ureq = "2.12.1"
url = "2.5.7"
+hound = "3.5.1"
+rustfft = "6.4.1"
[dev-dependencies]
diff --git a/catgrad-llm/scripts/llm.py b/catgrad-llm/scripts/llm.py
index 07f262e1..aa804a2d 100644
--- a/catgrad-llm/scripts/llm.py
+++ b/catgrad-llm/scripts/llm.py
@@ -10,7 +10,7 @@
from transformers import (
AutoModelForCausalLM,
- AutoModelForImageTextToText,
+ AutoModelForMultimodalLM,
AutoProcessor,
AutoTokenizer,
logging,
@@ -247,6 +247,7 @@ def run_tool_chat(tokenizer, model, prompt, args):
parser.add_argument("-p", "--prompt", type=str, default="Category theory is")
parser.add_argument("-s", "--seq-len", type=int, default=10)
parser.add_argument("-i", "--image", type=str, default=None)
+ parser.add_argument("-a", "--audio", type=str, default=None)
parser.add_argument("-r", "--raw", action="store_true")
parser.add_argument("-t", "--thinking", action="store_true")
parser.add_argument(
@@ -266,19 +267,19 @@ def run_tool_chat(tokenizer, model, prompt, args):
if args.tool_use and args.raw:
parser.error("--tool-use does not support --raw")
- if args.image is None:
+ if args.image is None and args.audio is None:
tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision)
try:
model = AutoModelForCausalLM.from_pretrained(
args.model, revision=args.revision, dtype=args.dtype
)
except:
- model = AutoModelForImageTextToText.from_pretrained(
+ model = AutoModelForMultimodalLM.from_pretrained(
args.model, revision=args.revision, dtype=args.dtype
)
else:
processor = AutoProcessor.from_pretrained(args.model, revision=args.revision)
- model = AutoModelForImageTextToText.from_pretrained(
+ model = AutoModelForMultimodalLM.from_pretrained(
args.model, revision=args.revision, dtype=args.dtype
)
@@ -290,6 +291,7 @@ def run_tool_chat(tokenizer, model, prompt, args):
if (
args.image is None
+ and args.audio is None
and not args.raw
and not args.tool_use
and tokenizer.chat_template is not None
@@ -309,7 +311,7 @@ def run_tool_chat(tokenizer, model, prompt, args):
model.generation_config.top_p = None
model.generation_config.top_k = None
- if args.image is None:
+ if args.image is None and args.audio is None:
if args.tool_use:
output = run_tool_chat(tokenizer, model, prompt, args)
else:
@@ -322,13 +324,16 @@ def run_tool_chat(tokenizer, model, prompt, args):
)
output = tokenizer.decode(logits[0], skip_special_tokens=True)
else:
+ content = [{"type": "text", "text": prompt}]
+ if args.image:
+ content += [{"type": "image", "path": args.image}]
+ if args.audio:
+ content += [{"type": "audio", "path": args.audio}]
+
messages = [
{
"role": "user",
- "content": [
- {"type": "text", "text": prompt},
- {"type": "image", "path": args.image},
- ],
+ "content": content,
}
]
try:
diff --git a/catgrad-llm/src/helpers/tool_calls.rs b/catgrad-llm/src/helpers/tool_calls.rs
index bd19b8d2..c554356a 100644
--- a/catgrad-llm/src/helpers/tool_calls.rs
+++ b/catgrad-llm/src/helpers/tool_calls.rs
@@ -50,6 +50,10 @@ pub fn parse_qwen3_5_tool_calls(output: &str) -> Result