From 59a96768f29e516a22aa64eb33412218793a724e Mon Sep 17 00:00:00 2001
From: georgewhewell
Date: Tue, 28 Apr 2026 20:40:12 +0200
Subject: [PATCH 1/3] utils: support loading model files from a local directory

`get_model_files` and `get_model_chat_template` now treat the model
identifier as a local directory if it's an existing path on disk; that
directory must look like a HuggingFace snapshot (config.json,
tokenizer.json, tokenizer_config.json, and either model.safetensors or
model.safetensors.index.json + shards). Otherwise the existing HF hub
download path is used unchanged.
---
 catgrad-llm/src/utils/mod.rs | 110 +++++++++++++++++++++++++++++------
 1 file changed, 91 insertions(+), 19 deletions(-)

diff --git a/catgrad-llm/src/utils/mod.rs b/catgrad-llm/src/utils/mod.rs
index 9641b06d..db9a2715 100644
--- a/catgrad-llm/src/utils/mod.rs
+++ b/catgrad-llm/src/utils/mod.rs
@@ -59,10 +59,22 @@ pub fn from_json_reader<T: serde::de::DeserializeOwned, R: std::io::Read>(
     serde_path_to_error::deserialize(&mut deserializer).map_err(LLMError::from)
 }
 
+/// Resolve a model identifier into the four file paths catgrad-llm needs.
+///
+/// If `model` points to an existing local directory, files are read from
+/// there (the directory must look like a HuggingFace snapshot — at minimum
+/// `config.json`, `tokenizer.json`, `tokenizer_config.json`, and either
+/// `model.safetensors` or `model.safetensors.index.json` plus its shards).
+/// Otherwise `model` is treated as a HuggingFace repo id and downloaded.
 pub fn get_model_files(
     model: &str,
     revision: &str,
 ) -> Result<(Vec<PathBuf>, PathBuf, PathBuf, PathBuf)> {
+    let local = Path::new(model);
+    if local.is_dir() {
+        return local_model_files(local);
+    }
+
     let api = build_hf_api()?;
     let repo = api.repo(Repo::with_revision(
         model.to_string(),
         RepoType::Model,
         revision.to_string(),
     ));
@@ -103,28 +115,87 @@ pub fn get_model_files(
     Ok((m, c, t, tc))
 }
 
+fn local_chat_template(dir: &Path) -> Result<String> {
+    let jinja = dir.join("chat_template.jinja");
+    if jinja.is_file() {
+        return Ok(std::fs::read_to_string(jinja)?);
+    }
+    let tc_path = dir.join("tokenizer_config.json");
+    let tc = std::fs::read_to_string(&tc_path)?;
+    let tokenizer_config: serde_json::Value = from_json_str(&tc)?;
+    tokenizer_config
+        .get("chat_template")
+        .and_then(|v| v.as_str())
+        .map(|s| s.to_string())
+        .ok_or(LLMError::InvalidModelConfig(
+            "Missing or invalid `chat_template` in tokenizer config".to_string(),
+        ))
+}
+
+fn local_model_files(dir: &Path) -> Result<(Vec<PathBuf>, PathBuf, PathBuf, PathBuf)> {
+    let must_exist = |name: &str| -> Result<PathBuf> {
+        let p = dir.join(name);
+        if p.is_file() {
+            Ok(p)
+        } else {
+            Err(LLMError::InvalidModelConfig(format!(
+                "local model dir {} missing required file `{name}`",
+                dir.display()
+            )))
+        }
+    };
+
+    let weights = if dir.join("model.safetensors.index.json").is_file() {
+        let index = std::fs::File::open(dir.join("model.safetensors.index.json"))?;
+        let json: serde_json::Value = from_json_reader(index)?;
+        let weight_map = json.get("weight_map").and_then(|v| v.as_object()).ok_or(
+            LLMError::InvalidModelConfig("local index missing or invalid `weight_map`".to_string()),
+        )?;
+        let mut shards = HashSet::new();
+        for v in weight_map.values() {
+            let name = v.as_str().ok_or(LLMError::InvalidModelConfig(
+                "weight_map contained non-string values".to_string(),
+            ))?;
+            shards.insert(dir.join(name));
+        }
+        shards.into_iter().collect()
+    } else {
+        vec![must_exist("model.safetensors")?]
+    };
+
+    Ok((
+        weights,
+        must_exist("config.json")?,
+        must_exist("tokenizer.json")?,
+        must_exist("tokenizer_config.json")?,
+    ))
+}
+
 // Try getting the model's chat template from the repository
 pub fn get_model_chat_template(model: &str, revision: &str) -> Result<String> {
-    let api = build_hf_api()?;
-    let repo = api.repo(Repo::with_revision(
-        model.to_string(),
-        RepoType::Model,
-        revision.to_string(),
-    ));
-
-    let chat_template = if let Ok(ct) = repo.get("chat_template.jinja") {
-        std::fs::read_to_string(ct)?
+    let chat_template = if Path::new(model).is_dir() {
+        local_chat_template(Path::new(model))?
     } else {
-        let tc_path = repo.get("tokenizer_config.json")?;
-        let tc = std::fs::read_to_string(tc_path)?;
-        let tokenizer_config: serde_json::Value = from_json_str(&tc)?;
-        tokenizer_config
-            .get("chat_template")
-            .and_then(|v| v.as_str())
-            .ok_or(LLMError::InvalidModelConfig(
-                "Missing or invalid `chat_template` in tokenizer config".to_string(),
-            ))?
-            .to_string()
+        let api = build_hf_api()?;
+        let repo = api.repo(Repo::with_revision(
+            model.to_string(),
+            RepoType::Model,
+            revision.to_string(),
+        ));
+        if let Ok(ct) = repo.get("chat_template.jinja") {
+            std::fs::read_to_string(ct)?
+        } else {
+            let tc_path = repo.get("tokenizer_config.json")?;
+            let tc = std::fs::read_to_string(tc_path)?;
+            let tokenizer_config: serde_json::Value = from_json_str(&tc)?;
+            tokenizer_config
+                .get("chat_template")
+                .and_then(|v| v.as_str())
+                .ok_or(LLMError::InvalidModelConfig(
+                    "Missing or invalid `chat_template` in tokenizer config".to_string(),
+                ))?
+                .to_string()
+        }
     };
     // Some chat templates contain these tags that are not used for inference.
     // If more variants show up a regex may be needed later on.
@@ -382,6 +453,7 @@ pub fn get_model(
             dtype,
         )?),
         "GPT2LMHeadModel" => Box::new(models::gpt2::GPT2Model::new(config_json, dtype)?),
+        "TalkieForCausalLM" => Box::new(models::talkie::TalkieModel::new(config_json, dtype)?),
         _ => {
             return Err(LLMError::InvalidModelConfig(format!(
                 "Unsupported model architecture: {}",

From 0659f01fd388690cd9fd5c2887cf059859402153 Mon Sep 17 00:00:00 2001
From: georgewhewell
Date: Tue, 28 Apr 2026 20:40:30 +0200
Subject: [PATCH 2/3] Talkie 13B model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Talkie is a 40-layer/40-head decoder-only transformer (talkie-lm.com,
github.com/talkie-lm/talkie) with the standard Llama backbone plus four
small departures, all expressible with existing catgrad operators:

1. RMSNorm everywhere is unweighted (F.rms_norm with no gamma),
   including a norm immediately after the embedding.
2. QK-norm — RMSNorm is applied to Q and K after RoPE.
3. Per-head and per-layer learned gains — head_gain ([H]) on Q after
   QK-norm, and scalar attn_gain / mlp_gain / embed_skip on the
   residual branches.
4. Embedding-skip residual — the post-input-norm activations are
   threaded through every block as e_x and added back via a learned
   scalar.

The lm_head is an untied [V, D] parameter (not a Linear) scaled by a
learned scalar (lm_head_gain.w_g) before the final matmul.

Talkie's RoPE uses the opposite sin convention from catgrad's default;
we negate cache.sin once after init to match.

Architecture string: TalkieForCausalLM.

End-to-end inference reproduces the upstream PyTorch reference
byte-for-byte at greedy argmax for short sequences in bf16; on longer
sequences the cross-implementation bf16 noise floor (Metal vs CPU)
flips one borderline argmax per ~40 tokens on some prompts.
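For orientation, one block in PyTorch-style pseudocode (a sketch, not the
shipped code: D is the hidden size, and the attn/mlp bodies along with
QK-norm and head_gain are elided):

    def block(x, e_x):  # e_x: the post-input-norm embedding activations
        x = x + attn_gain * attn(F.rms_norm(x, (D,)))  # unweighted norm, scalar gain
        x = x + mlp_gain * mlp(F.rms_norm(x, (D,)))
        return x + embed_skip * e_x  # embedding-skip residual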
Test harness in scripts/compare/talkie_compare.sh. Helpers: - scripts/convert_talkie.py: pickle -> safetensors + tokenizer + config - scripts/llm_talkie.py: greedy-argmax pytorch reference - scripts/compare/talkie_compare.sh: token-level stability matrix --- catgrad-llm/scripts/compare/talkie_compare.sh | 267 +++++++++++++ catgrad-llm/scripts/convert_talkie.py | 288 ++++++++++++++ catgrad-llm/scripts/llm_talkie.py | 150 +++++++ catgrad-llm/src/models/mod.rs | 1 + catgrad-llm/src/models/talkie.rs | 369 ++++++++++++++++++ 5 files changed, 1075 insertions(+) create mode 100755 catgrad-llm/scripts/compare/talkie_compare.sh create mode 100644 catgrad-llm/scripts/convert_talkie.py create mode 100644 catgrad-llm/scripts/llm_talkie.py create mode 100644 catgrad-llm/src/models/talkie.rs diff --git a/catgrad-llm/scripts/compare/talkie_compare.sh b/catgrad-llm/scripts/compare/talkie_compare.sh new file mode 100755 index 00000000..5ab9765c --- /dev/null +++ b/catgrad-llm/scripts/compare/talkie_compare.sh @@ -0,0 +1,267 @@ +#!/usr/bin/env bash +# Stability harness for the Talkie port. +# +# Token-by-token diffs the catgrad implementation against the talkie +# package's PyTorch reference. The reference loader holds the model in +# RAM for the whole matrix (one load, all cases) — without that, each +# 13B Python process spikes to ~50 GB and back-to-back runs OOM. +# +# Three outputs per case: +# ref — talkie / pytorch (CPU, dtype-cast bf16 by default) +# cat-k — catgrad-llm with KV cache +# cat-nok — catgrad-llm without KV cache +# +# Comparisons: +# ref vs cat-k — does the port match the upstream reference? +# This is the headline correctness check. +# cat-k vs cat-nok — does catgrad's cache implementation match its +# uncached path? In bf16 this can drift by one +# argmax tie-break and is informative but not +# strictly a talkie-level concern. +# +# Modes: +# * Single-shot — set TALKIE_PROMPT or TALKIE_SEQLEN; runs that one case. +# * Matrix — default. Built-in suite of prompts × lengths. +# +# Required env: +# TALKIE_DIR directory with the converted model (config.json, +# model.safetensors, tokenizer.json, ...) +# TALKIE_VOCAB path to the original vocab.txt +# TALKIE_VENV path to a venv with the `talkie` package installed +# (`pip install git+https://github.com/talkie-lm/talkie`) +# +# Optional env: +# TALKIE_STYLE base | it (default: base) +# TALKIE_PROMPT single-shot override +# TALKIE_SEQLEN single-shot override +# TALKIE_DTYPE bf16 | fp16 | fp32 (default: bf16) +# TALKIE_CKPT fall back to talkie's native ckpt loader (slow, +# transient ~100 GB RAM peak). Default is to load +# bf16 safetensors from $TALKIE_DIR/model.safetensors. +# TALKIE_INCLUDE_NOCACHE=1 +# also run catgrad without KV cache and diff. This +# tests catgrad-internal consistency rather than +# talkie correctness; in bf16 the no-cache path +# drifts by one argmax tie-break for some prompts +# and is much slower (O(N²) per step). Off by default. + +set -euo pipefail + +: "${TALKIE_DIR:?set TALKIE_DIR to the converted model directory}" +: "${TALKIE_VOCAB:?set TALKIE_VOCAB to the original vocab.txt path}" +: "${TALKIE_VENV:?set TALKIE_VENV to a python venv with talkie installed}" + +STYLE="${TALKIE_STYLE:-base}" +DTYPE="${TALKIE_DTYPE:-bf16}" +SAFETENSORS="$TALKIE_DIR/model.safetensors" + +DIR=$(dirname "$0") +SCRIPTS=$(cd "$DIR/.." && pwd) +WORKSPACE=$(cd "$SCRIPTS/../.." 
&& pwd) +LLAMA_BIN="$WORKSPACE/target/release/examples/llama" + +[[ -x "$LLAMA_BIN" ]] || { + echo "build the example first:" >&2 + echo " cargo build --release --features metal --example llama -p catgrad-llm" >&2 + exit 1 +} + +# Build the test-case matrix. Each line is a JSON object emitted to a +# temp file. The Python ref process consumes them as JSONL, the bash +# loop iterates the same list for catgrad runs and diffing. +build_cases() { + local cases_jsonl="$1" + local out_dir="$2" + : > "$cases_jsonl" + for prompt in "${PROMPTS[@]}"; do + for seqlen in "${SEQLENS[@]}"; do + local key + key=$(printf '%s' "${prompt}__${seqlen}" | tr -c 'A-Za-z0-9' '_' | cut -c1-60) + python3 -c "import json,sys;print(json.dumps({'prompt':sys.argv[1],'seq_len':int(sys.argv[2]),'out':sys.argv[3]}))" \ + "$prompt" "$seqlen" "$out_dir/ref__$key" >> "$cases_jsonl" + done + done +} + +run_python_ref_batch() { + local cases_jsonl="$1" loader_args + if [[ -n "${TALKIE_CKPT:-}" ]]; then + loader_args=(--ckpt "$TALKIE_CKPT") + else + loader_args=(--safetensors "$SAFETENSORS") + fi + # shellcheck disable=SC1091 + source "$TALKIE_VENV/bin/activate" + python "$SCRIPTS/llm_talkie.py" \ + "${loader_args[@]}" --vocab "$TALKIE_VOCAB" --style "$STYLE" \ + --dtype "$DTYPE" --batch < "$cases_jsonl" +} + +cat_dtype() { + # Python side uses bf16/fp16/fp32; the llama example uses bf16/f16/f32. + case "$1" in + fp32) echo "f32" ;; + fp16) echo "f16" ;; + *) echo "$1" ;; + esac +} + +run_cat() { + local prompt="$1" seqlen="$2" out="$3" extra_flags="$4" + local cat_dt; cat_dt=$(cat_dtype "$DTYPE") + # shellcheck disable=SC2086 + if ! "$LLAMA_BIN" -m "$TALKIE_DIR" --raw $extra_flags --dtype "$cat_dt" \ + -p "$prompt" -s "$seqlen" > "$out" 2>"$out.err"; then + echo "[cat] llama failed for prompt=$prompt seqlen=$seqlen:" >&2 + sed 's/^/ /' < "$out.err" >&2 + return 1 + fi +} + +# Compare two outputs token-by-token (using the talkie tokenizer so +# token boundaries are real, not whitespace). Prints one of: +# +# ok — full match +# drift TOK/TOTAL CTX — diverged at token TOK out of TOTAL; CTX shows +# the first ref-vs-cat tokens after the split +# +# Used both to call PASS/FAIL and to surface the matched-prefix length, +# which is the actual signal in cross-implementation bf16: 100% prefix +# match is "exact"; partial match quantifies how far before noise flips +# a borderline argmax. 
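+#
+# Example (hypothetical values): `drift 17/40 ref=' the sky' cat=' a cloud'`
+# means both runs agree on the first 17 of 40 tokens, then split.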
+diff_summary() { + local a="$1" b="$2" + TALKIE_VOCAB="$TALKIE_VOCAB" TALKIE_STYLE="$STYLE" \ + "$TALKIE_VENV/bin/python" - "$a" "$b" <<'PY' +import os, pathlib, sys +from talkie.tokenizer import build_tokenizer +tk = build_tokenizer(os.environ["TALKIE_VOCAB"], style=os.environ.get("TALKIE_STYLE", "base")) +a = pathlib.Path(sys.argv[1]).read_text().rstrip("\n") +b = pathlib.Path(sys.argv[2]).read_text().rstrip("\n") +ta = tk.encode(a, allowed_special="all") +tb = tk.encode(b, allowed_special="all") +n = min(len(ta), len(tb)) +i = 0 +while i < n and ta[i] == tb[i]: + i += 1 +total = max(len(ta), len(tb)) +if i == len(ta) == len(tb): + print("ok") +else: + a_tail = tk.decode(ta[i : i + 5]) + b_tail = tk.decode(tb[i : i + 5]) + print(f"drift {i}/{total} ref={a_tail!r} cat={b_tail!r}") +PY +} + +run_matrix() { + local out_dir cases_jsonl + out_dir=$(mktemp -d) + cases_jsonl="$out_dir/cases.jsonl" + trap 'rm -rf "$out_dir"' RETURN + + build_cases "$cases_jsonl" "$out_dir" + + echo "Talkie stability matrix — dtype=$DTYPE style=$STYLE" + echo "ref: $([[ -n "${TALKIE_CKPT:-}" ]] && echo "$TALKIE_CKPT" || echo "$SAFETENSORS")" + echo "cat: $TALKIE_DIR" + echo "cases: $(wc -l <"$cases_jsonl") (${#PROMPTS[@]} prompts × ${#SEQLENS[@]} lens)" + echo + + # Phase 1: Python ref. One load, all cases. Slowest part. + echo "[ref] generating all reference outputs (one model load) …" + local t0; t0=$(date +%s) + run_python_ref_batch "$cases_jsonl" + echo "[ref] done in $(( $(date +%s) - t0 ))s" + echo + + # Phase 2: catgrad cat-k for every case. Optional cat-nok if requested + # (it's much slower and is a catgrad-internal consistency check, not a + # talkie-correctness one). + echo "[cat] running catgrad cached$([[ -n "${TALKIE_INCLUDE_NOCACHE:-}" ]] && echo " + uncached") for each case …" + local total_matched=0 total_tokens=0 + while IFS= read -r line; do + local prompt seqlen ref out_k out_nok + prompt=$(python3 -c "import json,sys;print(json.loads(sys.argv[1])['prompt'])" "$line") + seqlen=$(python3 -c "import json,sys;print(json.loads(sys.argv[1])['seq_len'])" "$line") + ref=$(python3 -c "import json,sys;print(json.loads(sys.argv[1])['out'])" "$line") + out_k="${ref/ref__/catk__}" + run_cat "$prompt" "$seqlen" "$out_k" "-k" + + local sum_k sum_nok="" + sum_k=$(diff_summary "$ref" "$out_k") + + if [[ -n "${TALKIE_INCLUDE_NOCACHE:-}" ]]; then + out_nok="${ref/ref__/catn__}" + run_cat "$prompt" "$seqlen" "$out_nok" "" + sum_nok=$(diff_summary "$out_k" "$out_nok") + fi + + # Accumulate token-level stability stats for the headline number. 
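+        # `sum_k` is either `ok` or `drift MATCHED/TOTAL ...` (see
+        # diff_summary above); the expansions in the else-branch peel
+        # MATCHED and TOTAL back out of that string.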
+ if [[ "$sum_k" = "ok" ]]; then + total_matched=$((total_matched + seqlen)) + total_tokens=$((total_tokens + seqlen)) + else + local matched=${sum_k#drift }; matched=${matched%%/*} + local denom=${sum_k#drift *\/}; denom=${denom%% *} + total_matched=$((total_matched + matched)) + total_tokens=$((total_tokens + denom)) + fi + + local row + row=$(printf "%-50s s=%-3d " "$prompt" "$seqlen") + if [[ "$sum_k" = "ok" && ( -z "$sum_nok" || "$sum_nok" = "ok" ) ]]; then + printf "%sPASS\n" "$row" + elif [[ "$sum_k" = "ok" ]]; then + printf "%sPASS (ref==cat-k); cat-k != cat-nok: %s\n" "$row" "$sum_nok" + DRIFT=$((DRIFT + 1)) + else + printf "%sFAIL (ref != cat-k): %s\n" "$row" "$sum_k" + [[ -n "$sum_nok" && "$sum_nok" != "ok" ]] && printf " cat-k != cat-nok: %s\n" "$sum_nok" + FAILED=$((FAILED + 1)) + fi + done < "$cases_jsonl" + + local pct=0 + [[ "$total_tokens" -gt 0 ]] && pct=$(( total_matched * 1000 / total_tokens )) + echo + if [[ -n "${TALKIE_INCLUDE_NOCACHE:-}" ]]; then + printf "summary: %d cases, %d ref-mismatches, %d cat-k/cat-nok drifts; token match %d/%d (%d.%d%%)\n" \ + "$(wc -l <"$cases_jsonl")" "$FAILED" "$DRIFT" \ + "$total_matched" "$total_tokens" "$((pct / 10))" "$((pct % 10))" + else + printf "summary: %d cases, %d ref-mismatches; token match %d/%d (%d.%d%%)\n" \ + "$(wc -l <"$cases_jsonl")" "$FAILED" \ + "$total_matched" "$total_tokens" "$((pct / 10))" "$((pct % 10))" + fi + return "$FAILED" +} + +# --------------------------------------------------------------------------- +# Single-shot mode +# --------------------------------------------------------------------------- + +if [[ -n "${TALKIE_PROMPT:-}" || -n "${TALKIE_SEQLEN:-}" ]]; then + PROMPTS=("${TALKIE_PROMPT:-Once upon a time}") + SEQLENS=("${TALKIE_SEQLEN:-40}") +else + # --------------------------------------------------------------------------- + # Default matrix: variety of prompts × two lengths. + # --------------------------------------------------------------------------- + PROMPTS=( + "The quick brown fox" + "Once upon a time" + "Category theory is" + "If scientists discover life on other planets," + 'Mr. Carnegie said: "I am' + ) + # Talkie's forward has no KV cache, so the Python ref is O(N²). + # Two lengths balance "do we match at all" against "does drift appear + # over many tokens" — 40 is enough that any sub-ulp bf16 disagreement + # on the top two logits will eventually pick a different argmax. + SEQLENS=(20 40) +fi + +FAILED=0 +DRIFT=0 +run_matrix diff --git a/catgrad-llm/scripts/convert_talkie.py b/catgrad-llm/scripts/convert_talkie.py new file mode 100644 index 00000000..2062385f --- /dev/null +++ b/catgrad-llm/scripts/convert_talkie.py @@ -0,0 +1,288 @@ +"""Convert a Talkie release into a catgrad-loadable HF-style directory. + +Input: a Talkie release as published on HuggingFace + (https://huggingface.co/talkie-lm/talkie-1930-13b-{base,it}): + - final.ckpt / rl-refined.pt / base.ckpt (raw torch.save dict) + - vocab.txt (tiktoken BPE) + +Output: a directory catgrad-llm's loader can consume: + - model.safetensors (bf16 weights, talkie-native key names) + - config.json ({"architectures": ["TalkieForCausalLM"], ...}) + - tokenizer.json (HF tokenizers, byte-level BPE) + - tokenizer_config.json (chat template for IT variants) + +Dependencies: torch, safetensors, tiktoken, tokenizers. 
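+
+The script finishes by round-tripping sample strings through both tiktoken
+and the converted HF tokenizer (`verify_tokenizer`); a conversion that
+fails that check should not be shipped.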
+ +Usage: + python convert_talkie.py \ + --ckpt /path/to/final.ckpt \ + --vocab /path/to/vocab.txt \ + --out /path/to/talkie-1930-13b-base \ + --style base +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import torch +from safetensors.torch import save_file +from tiktoken.load import load_tiktoken_bpe +from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers + + +BASE_VOCAB_SIZE = 65536 + +# Matches src/talkie/tokenizer.py — the tiktoken pat_str, joined with `|`. +PAT_STR = "|".join( + [ + r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?", + r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?", + r"\p{N}{1,3}", + r" ?[^\s\p{L}\p{N}]+[\r\n/]*", + r"\s*[\r\n]+", + r"\s+(?!\S)", + r"\s+", + ] +) + +BASE_SPECIAL_TOKENS = {"<|endoftext|>": BASE_VOCAB_SIZE - 1} +IT_SPECIAL_TOKENS = { + "<|endoftext|>": BASE_VOCAB_SIZE - 1, + "<|end|>": BASE_VOCAB_SIZE, + "<|user|>": BASE_VOCAB_SIZE + 1, + "<|assistant|>": BASE_VOCAB_SIZE + 2, + "<|system|>": BASE_VOCAB_SIZE + 3, +} + +CHAT_TEMPLATE = ( + "{%- for message in messages -%}" + "<|{{ message.role }}|>{{ message.content }}<|end|>" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}<|assistant|>{%- endif -%}" +) + +DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} + + +def convert_weights(ckpt_path: Path, out_path: Path, dtype: torch.dtype) -> dict: + """Pickle → safetensors. Preserves Talkie's native param names verbatim + (`embed.weight`, `blocks.{i}.attn.attn_query.weight`, …, `lm_head`, + `lm_head_gain.w_g`); strips `_orig_mod.` from torch.compile. + + Returns the inferred config dict so the caller can write `config.json`. + """ + raw = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if isinstance(raw, dict) and "model_state_dict" in raw: + sd = raw["model_state_dict"] + elif isinstance(raw, dict) and "model" in raw: + sd = raw["model"] + else: + sd = raw + sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()} + + # Sanity check: the keys we expect. + required = {"embed.weight", "lm_head", "lm_head_gain.w_g"} + missing = required - sd.keys() + if missing: + raise SystemExit(f"checkpoint missing keys: {sorted(missing)}") + + vocab_size, n_embd = sd["embed.weight"].shape + n_layer = 1 + max( + int(k.split(".")[1]) for k in sd if k.startswith("blocks.") + ) + n_head = sd["blocks.0.attn.head_gain.head_g"].shape[0] + head_dim = n_embd // n_head + + # Cast in place to halve peak memory: 13B fp32 + 13B bf16 ≈ 80 GB if held + # together; popping each fp32 tensor as we cast keeps it near 30 GB. + converted = {} + for k in list(sd.keys()): + converted[k] = sd.pop(k).to(dtype).contiguous() + save_file(converted, out_path) + + return { + "vocab_size": int(vocab_size), + "hidden_size": int(n_embd), + "num_hidden_layers": int(n_layer), + "num_attention_heads": int(n_head), + "head_dim": int(head_dim), + } + + +def write_config(out_dir: Path, weights_meta: dict, style: str) -> None: + """Synthesize config.json. Talkie has no native config — everything + is hardcoded in src/talkie/model.py's GPTConfig + the call to + `_precompute_rotary_embeddings(base=1_000_000)`. 
+ """ + eos_id = ( + IT_SPECIAL_TOKENS["<|end|>"] if style == "it" + else BASE_SPECIAL_TOKENS["<|endoftext|>"] + ) + config = { + "architectures": ["TalkieForCausalLM"], + "model_type": "talkie", + **weights_meta, + "max_position_embeddings": 2048, + "rope_theta": 1_000_000.0, + "rms_norm_eps": 1e-6, # F.rms_norm in talkie uses dtype-default; small ε is closest in fp32 + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "eos_token_id": eos_id, + } + if style == "it": + # IT model expanded the embedding to 65540 during fine-tuning. + config["vocab_size"] = max(config["vocab_size"], BASE_VOCAB_SIZE + 4) + + (out_dir / "config.json").write_text(json.dumps(config, indent=2)) + + +def _bytes_to_unicode() -> dict[int, str]: + """GPT-2 byte→unicode mapping. Reversibly maps every byte to a printable + Unicode codepoint so byte-level BPE can be expressed as a string-keyed + vocab + merges file.""" + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(256): + if b not in bs: + bs.append(b) + cs.append(256 + n) + n += 1 + return dict(zip(bs, [chr(c) for c in cs])) + + +def _bpe_split(token: bytes, ranks: dict[bytes, int]) -> tuple[bytes, bytes]: + """Recover the (left, right) merge that produced `token` by replaying the + BPE algorithm with all ranks strictly less than `ranks[token]`. + Standard recipe — same one transformers' TikTokenConverter uses.""" + parts = [bytes([b]) for b in token] + target = ranks[token] + while True: + min_idx, min_rank = None, None + for i in range(len(parts) - 1): + r = ranks.get(parts[i] + parts[i + 1]) + if r is not None and r < target and (min_rank is None or r < min_rank): + min_idx, min_rank = i, r + if min_idx is None: + break + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] + if len(parts) != 2: + raise ValueError(f"could not reduce token {token!r} (rank {target}) to a 2-part merge") + return parts[0], parts[1] + + +def write_tokenizer(vocab_path: Path, out_dir: Path, style: str) -> None: + """Convert tiktoken vocab.txt → HF tokenizer.json (byte-level BPE).""" + ranks = load_tiktoken_bpe(str(vocab_path)) + # Talkie drops the highest rank (reserved for <|endoftext|>). 
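+    # (rank BASE_VOCAB_SIZE - 1 = 65535, the id BASE_SPECIAL_TOKENS assigns
+    # to <|endoftext|>; it is re-added as a special token below.)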
+ ranks = {tok: r for tok, r in ranks.items() if r < BASE_VOCAB_SIZE - 1} + + b2u = _bytes_to_unicode() + + def encode(b: bytes) -> str: + return "".join(b2u[x] for x in b) + + vocab = {encode(tok): r for tok, r in ranks.items()} + merges = [] + for tok, _ in sorted(ranks.items(), key=lambda kv: kv[1]): + if len(tok) == 1: + continue + left, right = _bpe_split(tok, ranks) + merges.append((encode(left), encode(right))) + + tk = Tokenizer(models.BPE(vocab=vocab, merges=merges, fuse_unk=False, byte_fallback=False)) + tk.pre_tokenizer = pre_tokenizers.Sequence([ + pre_tokenizers.Split(pattern=Regex(PAT_STR), behavior="isolated"), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False), + ]) + tk.decoder = decoders.ByteLevel() + + specials = IT_SPECIAL_TOKENS if style == "it" else BASE_SPECIAL_TOKENS + tk.add_special_tokens(list(specials.keys())) + + tk.save(str(out_dir / "tokenizer.json")) + + tcfg = { + "model_max_length": 2048, + "bos_token": None, + "eos_token": "<|end|>" if style == "it" else "<|endoftext|>", + "pad_token": None, + "added_tokens_decoder": { + str(idx): {"content": tok, "special": True} + for tok, idx in specials.items() + }, + } + if style == "it": + tcfg["chat_template"] = CHAT_TEMPLATE + (out_dir / "tokenizer_config.json").write_text(json.dumps(tcfg, indent=2)) + + +def verify_tokenizer(vocab_path: Path, out_dir: Path, style: str) -> None: + """Round-trip a sample through both tiktoken and the converted HF tokenizer. + Failures here mean the BPE reconstruction or pre-tokenizer regex diverged — + do not ship without this passing.""" + import tiktoken + + ranks = load_tiktoken_bpe(str(vocab_path)) + ranks = {tok: r for tok, r in ranks.items() if r < BASE_VOCAB_SIZE - 1} + specials = IT_SPECIAL_TOKENS if style == "it" else BASE_SPECIAL_TOKENS + tt = tiktoken.Encoding( + name="talkie", pat_str=PAT_STR, mergeable_ranks=ranks, special_tokens=specials + ) + hf = Tokenizer.from_file(str(out_dir / "tokenizer.json")) + + samples = [ + "Hello, world!", + "The quick brown fox jumps over the lazy dog.", + "It's a fine day in 1930.\nLet us discuss aeronautics.", + " leading spaces and\ttabs\n\nand newlines", + "café naïve résumé — em-dash and ellipsis…", + ] + if style == "it": + samples.append("<|user|>hi<|end|><|assistant|>") + + for s in samples: + a = tt.encode(s, allowed_special=set(specials)) + b = hf.encode(s).ids + if a != b: + raise SystemExit( + f"tokenizer mismatch on {s!r}\n tiktoken: {a}\n hf: {b}" + ) + print(f"tokenizer round-trip ok ({len(samples)} samples)") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) + ap.add_argument("--ckpt", type=Path, required=True) + ap.add_argument("--vocab", type=Path, required=True) + ap.add_argument("--out", type=Path, required=True) + ap.add_argument("--style", choices=["base", "it"], required=True) + ap.add_argument("--dtype", choices=DTYPES.keys(), default="bf16") + args = ap.parse_args() + + args.out.mkdir(parents=True, exist_ok=True) + + print("converting weights…") + meta = convert_weights(args.ckpt, args.out / "model.safetensors", DTYPES[args.dtype]) + + print("writing config.json…") + write_config(args.out, meta, args.style) + + print("writing tokenizer.json + tokenizer_config.json…") + write_tokenizer(args.vocab, args.out, args.style) + + print("verifying tokenizer round-trip…") + verify_tokenizer(args.vocab, args.out, args.style) + + print(f"done → {args.out}") + + +if __name__ == "__main__": + main() diff --git a/catgrad-llm/scripts/llm_talkie.py 
b/catgrad-llm/scripts/llm_talkie.py new file mode 100644 index 00000000..c76a1c48 --- /dev/null +++ b/catgrad-llm/scripts/llm_talkie.py @@ -0,0 +1,150 @@ +"""Greedy-decode reference for Talkie. Mirrors `scripts/llm.py`'s +`do_sample=False` path but uses the talkie package (talkie isn't in +HF transformers, so AutoModelForCausalLM doesn't apply). + +Output is true argmax (no Gumbel sampling, no temperature) so the +result is byte-identical across runs and directly comparable to +catgrad-llm's argmax-greedy decode. + +Two loader paths: + + * `--ckpt` — talkie's native `final.ckpt`. Matches upstream + load semantics, but `torch.load` materialises the + whole 53 GB fp32 dict in RAM before casting to + bf16 (transient peak ≈ 100 GB). + * `--safetensors` — pre-converted bf16 safetensors (output of + `convert_talkie.py`). Peak ≈ 52 GB. + +Both produce bit-identical model state once cast, so use safetensors +for the stability harness. + +Two run modes: + + * single-shot — `--prompt P --seq-len N`, output goes to `--out` or stdout. + * batch — `--batch`, reads JSON-line cases from stdin + (`{"prompt": "...", "seq_len": N, "out": "/path"}`), + loads the model once, runs all cases. The matrix harness + uses this so we don't pay the 50 GB load N times. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +import torch + +from talkie.model import GPTConfig, TalkieModel, resize_model_embeddings +from talkie.tokenizer import IT_VOCAB_SIZE, build_tokenizer + + +DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} + + +def load_from_safetensors( + path: Path, device: torch.device, target_vocab_size: int | None +) -> TalkieModel: + """Build a TalkieModel from our pre-converted bf16 safetensors. + + Skips the ckpt → fp32-dict → cast round-trip in talkie's own loader, + which would otherwise spike to ~100 GB on a 13B model. 
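+    Only `vocab_size` is read from the weights; every other hyperparameter
+    comes from talkie's `GPTConfig` defaults, which this loader assumes
+    match the checkpoint.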
+ """ + from safetensors.torch import load_file + + sd = load_file(str(path), device="cpu") + vocab_size = sd["embed.weight"].shape[0] + config = GPTConfig(vocab_size=vocab_size) + + cpu = torch.device("cpu") + model = TalkieModel(config, cpu) + model.load_state_dict(sd, strict=True) + del sd + + if target_vocab_size is not None and vocab_size < target_vocab_size: + model = resize_model_embeddings(model, target_vocab_size, cpu) + + model = model.to(dtype=torch.bfloat16).to(device) + model.device = device + model.eval() + return model + + +def load_model(args, device: torch.device) -> TalkieModel: + target_vocab = IT_VOCAB_SIZE if args.style == "it" else None + if args.safetensors is not None: + model = load_from_safetensors(args.safetensors, device, target_vocab) + else: + from talkie.model import load_checkpoint + model = load_checkpoint(str(args.ckpt), device, target_vocab_size=target_vocab) + if DTYPES[args.dtype] != torch.bfloat16: + model = model.to(dtype=DTYPES[args.dtype]) + model.eval() + return model + + +def generate(model: TalkieModel, tokenizer, prompt: str, seq_len: int) -> str: + ids = tokenizer.encode(prompt, allowed_special="all") + x = torch.tensor([ids], device=model.device, dtype=torch.long) + with torch.no_grad(): + for _ in range(seq_len): + logits = model.forward(x) + next_id = int(torch.argmax(logits, dim=-1).item()) + x = torch.cat( + [x, torch.tensor([[next_id]], device=model.device, dtype=torch.long)], + dim=1, + ) + return tokenizer.decode(x[0].tolist()) + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) + src = ap.add_mutually_exclusive_group(required=True) + src.add_argument("--ckpt", type=Path, help="talkie native final.ckpt (slow load)") + src.add_argument("--safetensors", type=Path, help="bf16 safetensors (fast load)") + ap.add_argument("--vocab", type=Path, required=True) + ap.add_argument("--style", choices=["base", "it"], default="base") + ap.add_argument("--dtype", choices=DTYPES.keys(), default="bf16") + ap.add_argument("--device", default=None, + help="Defaults to cuda if available, otherwise cpu.") + ap.add_argument("-p", "--prompt", help="single-shot prompt") + ap.add_argument("-s", "--seq-len", type=int, default=20) + ap.add_argument("-o", "--out", type=Path, help="single-shot output file (default stdout)") + ap.add_argument("--batch", action="store_true", + help="Read JSONL cases from stdin (one {prompt, seq_len, out} per line).") + args = ap.parse_args() + + if args.batch and args.prompt is not None: + ap.error("--batch and --prompt are mutually exclusive") + if not args.batch and args.prompt is None: + ap.error("either --batch or --prompt is required") + + device = torch.device( + args.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) + print(f"loading talkie on {device} dtype={args.dtype}…", file=sys.stderr, flush=True) + tokenizer = build_tokenizer(str(args.vocab), style=args.style) + model = load_model(args, device) + print("loaded.", file=sys.stderr, flush=True) + + if args.batch: + for raw in sys.stdin: + line = raw.strip() + if not line: + continue + case = json.loads(line) + text = generate(model, tokenizer, case["prompt"], int(case["seq_len"])) + Path(case["out"]).write_text(text + "\n") + print(f" case {case['prompt']!r} s={case['seq_len']} → {case['out']}", + file=sys.stderr, flush=True) + else: + text = generate(model, tokenizer, args.prompt, args.seq_len) + if args.out: + args.out.write_text(text + "\n") + else: + print(text) + + +if __name__ == "__main__": + main() diff --git 
a/catgrad-llm/src/models/mod.rs b/catgrad-llm/src/models/mod.rs
index 959bcb32..85e8174d 100644
--- a/catgrad-llm/src/models/mod.rs
+++ b/catgrad-llm/src/models/mod.rs
@@ -14,3 +14,4 @@ pub mod qwen3;
 pub mod qwen3_5;
 pub mod siglip;
 pub mod smolvlm2;
+pub mod talkie;
diff --git a/catgrad-llm/src/models/talkie.rs b/catgrad-llm/src/models/talkie.rs
new file mode 100644
index 00000000..35b8d7e0
--- /dev/null
+++ b/catgrad-llm/src/models/talkie.rs
@@ -0,0 +1,369 @@
+//! Talkie 13B — a decoder-only transformer with the standard Llama backbone
+//! plus four small departures:
+//!
+//! 1. RMSNorm everywhere is **unweighted** (`F.rms_norm(x, ..)` with no γ),
+//!    including a norm immediately after the embedding.
+//! 2. **QK-norm** — RMSNorm is applied to Q and K *after* RoPE.
+//! 3. **Per-head and per-layer learned gains** — `head_gain` (shape `[H]`)
+//!    on Q after QK-norm, and scalar `attn_gain` / `mlp_gain` on the
+//!    attention and MLP residual branches.
+//! 4. **Embedding skip connection** — the post-input-norm activations are
+//!    threaded through every block as `e_x` and added back via a scalar
+//!    `embed_skip` gain.
+//!
+//! The lm_head is an untied `[V, D]` parameter (not a `Linear`) and is
+//! scaled by a learned scalar (`lm_head_gain.w_g`) before the final matmul.
+//!
+//! RoPE convention differs from catgrad's default by a sign on `sin`
+//! (talkie: `y1 = x1·cos + x2·sin`, catgrad: `y1 = x1·cos - x2·sin`); we
+//! negate `cache.sin` once after init to match.
+
+#![allow(clippy::too_many_arguments)]
+
+use crate::config::{EosTokenId, LLMConfig};
+use crate::helpers::*;
+use catgrad::prelude::ops::*;
+use catgrad::prelude::*;
+use nn::*;
+use serde::Deserialize;
+
+#[derive(Debug, Clone, Default, Deserialize)]
+#[serde(default)]
+pub struct TalkieConfig {
+    vocab_size: usize,
+    hidden_size: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    head_dim: usize,
+    max_position_embeddings: usize,
+    rope_theta: f32,
+    rms_norm_eps: f32,
+    tie_word_embeddings: bool,
+    eos_token_id: Option<EosTokenId>,
+}
+
+impl LLMConfig for TalkieConfig {
+    fn num_hidden_layers(&self) -> usize {
+        self.num_hidden_layers
+    }
+
+    fn num_key_value_heads(&self) -> usize {
+        // Talkie has no GQA — kv heads == attention heads.
+        self.num_attention_heads
+    }
+
+    fn rope_theta(&self) -> f32 {
+        self.rope_theta
+    }
+
+    fn max_position_embeddings(&self) -> usize {
+        self.max_position_embeddings
+    }
+
+    fn get_head_dim(&self) -> usize {
+        self.head_dim
+    }
+
+    fn eos_token_id(&self) -> Option<EosTokenId> {
+        self.eos_token_id.clone()
+    }
+}
+
+impl TalkieConfig {
+    /// Talkie's MLP intermediate width: `round((8/3)·D / 128) · 128`.
+    /// Hardcoded in `src/talkie/model.py`'s `MLP.__init__`; not stored in
+    /// the checkpoint or any config file.
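+    /// e.g. for the test config at the bottom of this file, D = 192:
+    /// (8/3)·192 = 512, 512/128 = 4, 4·128 = 512.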
+    fn intermediate_size(&self) -> usize {
+        let h = self.hidden_size as f32;
+        (((8.0 / 3.0) * h / 128.0).round() as usize) * 128
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TalkieModel {
+    config: TalkieConfig,
+    dtype: Dtype,
+}
+
+impl LLMModel for TalkieModel {
+    fn config(&self) -> &dyn LLMConfig {
+        &self.config
+    }
+
+    fn dtype(&self) -> Dtype {
+        self.dtype
+    }
+}
+
+impl TalkieModel {
+    pub fn new(config_json: &serde_json::Value, dtype: Dtype) -> crate::Result<Self> {
+        let config: TalkieConfig = serde_json::from_value(config_json.clone())?;
+        Ok(Self { config, dtype })
+    }
+
+    fn forward(
+        &self,
+        builder: &Builder,
+        p: Path,
+        x: Var,
+        in_k: Var,
+        in_v: Var,
+        max_positions: Var,
+    ) -> Vec<Var> {
+        let eps = self.config.rms_norm_eps;
+
+        // Embed → input RMSNorm (unweighted) → save as e_x for embed-skip.
+        let x = embeddings(builder, p.extend(["embed"]).unwrap(), x);
+        let x = rmsnorm_raw::<3>(builder, eps, x);
+        let e_x = x.clone();
+
+        let [_, s, _] = unpack::<3>(builder, shape(builder, x.clone()));
+        let [_, _, _, pos, _] = unpack::<5>(builder, shape(builder, in_k.clone()));
+        let attention_mask = causal_mask(builder, s, pos.clone());
+
+        let mut cache = Cache::init(
+            builder,
+            &self.config,
+            max_positions.clone(),
+            max_positions,
+            in_k,
+            in_v,
+        );
+
+        // Talkie's RoPE has the opposite sin convention from catgrad's. Negate
+        // the sin table once here; everything downstream uses it as-is.
+        let neg = constant(builder, -1.0, &shape(builder, cache.sin.clone()));
+        cache.sin = cache.sin.clone() * neg;
+
+        let mut x = x;
+        for i in 0..self.config.num_hidden_layers {
+            x = self.layer(
+                builder,
+                i,
+                attention_mask.clone(),
+                &mut cache,
+                pos.clone(),
+                e_x.clone(),
+                p.extend(["blocks", &i.to_string()]).unwrap(),
+                x,
+            );
+        }
+
+        let x = rmsnorm_raw::<3>(builder, eps, x);
+
+        // lm_head with WeightGain: scale the [V, D] weight by a scalar
+        // before the matmul. lm_head is a bare Parameter, not a Linear,
+        // so we can't go through `linear_no_bias` (which expects `{path}.weight`).
+        let lm_head = param(builder, &p.extend(["lm_head"]).unwrap());
+        let w_g = param(builder, &p.extend(["lm_head_gain", "w_g"]).unwrap());
+        let lm_sh = shape(builder, lm_head.clone());
+        let w_g = broadcast(builder, lm_sh, w_g);
+        let lm_head = lm_head * w_g;
+        let x = linear_no_bias_param(
+            builder,
+            self.config.hidden_size,
+            self.config.vocab_size,
+            lm_head,
+            x,
+        );
+
+        let x = argmax(builder, x);
+        let (out_k, out_v) = cache.get_kv_cache(builder);
+        vec![x, out_k, out_v]
+    }
+
+    fn layer(
+        &self,
+        builder: &Builder,
+        layer_id: usize,
+        attention_mask: Var,
+        cache: &mut Cache,
+        pos: Var,
+        e_x: Var,
+        p: Path,
+        x: Var,
+    ) -> Var {
+        let eps = self.config.rms_norm_eps;
+
+        // Pre-attn norm (unweighted) → attn → scalar attn_gain → residual.
+        let res = x.clone();
+        let x = rmsnorm_raw::<3>(builder, eps, x);
+        let x = self.attention(
+            builder,
+            layer_id,
+            attention_mask,
+            cache,
+            pos,
+            p.extend(["attn"]).unwrap(),
+            x,
+        );
+        let x = scale(builder, p.extend(["attn_gain", "a_g"]).unwrap(), x);
+        let x = res + x;
+
+        // Pre-mlp norm (unweighted) → mlp → scalar mlp_gain → residual.
+        let res = x.clone();
+        let x = rmsnorm_raw::<3>(builder, eps, x);
+        let x = self.mlp(builder, p.extend(["mlp"]).unwrap(), x);
+        let x = scale(builder, p.extend(["mlp_gain", "a_g"]).unwrap(), x);
+        let x = res + x;
+
+        // Embedding-skip residual: x += embed_skip * e_x.
+        let skip = scale(builder, p.extend(["embed_skip", "a_g"]).unwrap(), e_x);
+        x + skip
+    }
+
+    fn mlp(&self, builder: &Builder, p: Path, x: Var) -> Var {
+        let h = self.config.hidden_size;
+        let i = self.config.intermediate_size();
+        let gate = linear_no_bias(builder, h, i, p.extend(["mlp_gate"]).unwrap(), x.clone());
+        let up = linear_no_bias(builder, h, i, p.extend(["mlp_linear"]).unwrap(), x);
+        let x = silu(builder, gate) * up;
+        linear_no_bias(builder, i, h, p.extend(["mlp_resid"]).unwrap(), x)
+    }
+
+    fn attention(
+        &self,
+        builder: &Builder,
+        layer_id: usize,
+        attention_mask: Var,
+        cache: &mut Cache,
+        pos: Var,
+        p: Path,
+        x: Var,
+    ) -> Var {
+        let dim = self.config.hidden_size;
+        let num_heads = self.config.num_attention_heads;
+        let head_dim = self.config.head_dim;
+        let eps = self.config.rms_norm_eps;
+
+        let [b, s, _] = unpack::<3>(builder, shape(builder, x.clone()));
+
+        let q = linear_no_bias(
+            builder,
+            dim,
+            dim,
+            p.extend(["attn_query"]).unwrap(),
+            x.clone(),
+        );
+        let k = linear_no_bias(
+            builder,
+            dim,
+            dim,
+            p.extend(["attn_key"]).unwrap(),
+            x.clone(),
+        );
+        let v = linear_no_bias(builder, dim, dim, p.extend(["attn_value"]).unwrap(), x);
+
+        let qkv_sh = shape!(builder, b, s, num_heads, head_dim);
+        let q = reshape(builder, qkv_sh.clone(), q);
+        let k = reshape(builder, qkv_sh.clone(), k);
+        let v = reshape(builder, qkv_sh, v);
+
+        let q = transpose(builder, 1, 2, q);
+        let k = transpose(builder, 1, 2, k);
+        let v = transpose(builder, 1, 2, v);
+
+        // RoPE (cache.sin already sign-flipped at init).
+        let q = apply_rope_embedding(
+            builder,
+            pos.clone(),
+            head_dim,
+            cache.cos.clone(),
+            cache.sin.clone(),
+            q,
+        );
+        let k = apply_rope_embedding(
+            builder,
+            pos,
+            head_dim,
+            cache.cos.clone(),
+            cache.sin.clone(),
+            k,
+        );
+
+        // QK-norm (RMSNorm with no learned weight, over last dim).
+        let q = rmsnorm_raw::<4>(builder, eps, q);
+        let k = rmsnorm_raw::<4>(builder, eps, k);
+
+        // Per-head gain on Q only. head_g: [H] → broadcast to [B, H, S, D].
+        let head_g = param(builder, &p.extend(["head_gain", "head_g"]).unwrap());
+        let head_g = reshape(builder, shape!(builder, 1, num_heads, 1, 1), head_g);
+        let q_sh = shape(builder, q.clone());
+        let head_g = broadcast(builder, q_sh, head_g);
+        let q = q * head_g;
+
+        let (k, v) = cache.update_kv_cache(builder, layer_id, k, v);
+
+        let tk = transpose(builder, 2, 3, k);
+        let attn = matmul(builder, q, tk);
+        let attn_sh = shape(builder, attn.clone());
+        let denom = constant(builder, f32::sqrt(head_dim as f32), &attn_sh);
+        let denom = cast(builder, denom, dtype(builder, attn.clone()));
+        let mut attn = attn / denom;
+
+        let mask = cast(builder, attention_mask, dtype(builder, attn.clone()));
+        let mask = broadcast(builder, attn_sh, mask);
+        attn = attn + mask;
+
+        let attn = softmax(builder, attn);
+        let attn = matmul(builder, attn, v);
+
+        let attn = transpose(builder, 1, 2, attn);
+        let attn = reshape(builder, shape!(builder, b, s, dim), attn);
+
+        linear_no_bias(builder, dim, dim, p.extend(["attn_resid"]).unwrap(), attn)
+    }
+}
+
+/// Multiply `x` by a scalar parameter at `p` (shape `[1]`), broadcasting
+/// over `x`'s shape. Used for `attn_gain`, `mlp_gain`, `embed_skip`.
+fn scale(builder: &Builder, p: Path, x: Var) -> Var {
+    let g = param(builder, &p);
+    let sh = shape(builder, x.clone());
+    let g = broadcast(builder, sh, g);
+    x * g
+}
+
+impl DynModule for TalkieModel {
+    fn path(&self) -> Path {
+        path(vec!["talkie"]).expect("invalid model path")
+    }
+
+    fn def(&self, builder: &Builder, args: Vec<Var>) -> Vec<Var> {
+        let [x, in_k, in_v, max_positions]: [Var; 4] = args.try_into().expect("expected 4 inputs");
+        self.forward(builder, self.path(), x, in_k, in_v, max_positions)
+    }
+
+    fn ty(&self) -> (Vec<Object>, Vec<Object>) {
+        llm_type(&self.config, self.dtype())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn term_typechecks() {
+        // Tiny shape that still exercises every novel piece (QK-norm,
+        // head_gain, attn_gain, mlp_gain, embed_skip, lm_head_gain) and the
+        // n_mlp formula `round((8/3)·H/128)·128` (here: 192 → 512).
+        let cfg = serde_json::json!({
+            "vocab_size": 64,
+            "hidden_size": 192,
+            "num_hidden_layers": 2,
+            "num_attention_heads": 2,
+            "head_dim": 96,
+            "max_position_embeddings": 16,
+            "rope_theta": 1_000_000.0,
+            "rms_norm_eps": 1e-6,
+            "tie_word_embeddings": false,
+            "eos_token_id": 63,
+        });
+        let model = TalkieModel::new(&cfg, Dtype::BF16).expect("model construction");
+        assert_eq!(model.config.intermediate_size(), 512);
+        model
+            .term()
+            .expect("term construction failed (sort/type mismatch)");
+    }
+}

From b458fda46bfa83546ed82e361e66f2535aa95205 Mon Sep 17 00:00:00 2001
From: georgewhewell
Date: Wed, 29 Apr 2026 12:24:32 +0200
Subject: [PATCH 3/3] Talkie: target lewtun/talkie-1930-13b-it-hf naming, drop
 our converter
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The decoder stack now reads from `model.embed.weight`,
`model.blocks.{i}.…` — matching the HF port at
`lewtun/talkie-1930-13b-it-hf` (`TalkieForCausalLM` with
`self.model = TalkieModel(…)` and `lm_head`/`lm_head_gain.w_g` at the
root).
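For illustration, the renaming is just a `model.` prefix on the decoder
stack (hypothetical sample keys; `lm_head` and `lm_head_gain.w_g` keep
their root-level names):

  embed.weight                    -> model.embed.weight
  blocks.0.attn.attn_query.weight -> model.blocks.0.attn.attn_query.weight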
That repo includes a full HF-format checkpoint plus a `tokenizer.json` already in HF tokenizers form, so our pickle→safetensors converter and greedy-argmax reference are no longer needed: - rm catgrad-llm/scripts/convert_talkie.py - rm catgrad-llm/scripts/llm_talkie.py - rm catgrad-llm/scripts/compare/talkie_compare.sh End-to-end run: ./target/release/examples/llama -m lewtun/talkie-1930-13b-it-hf \ -k -s 60 --dtype bf16 -p "Write a short poem about the wireless telegraph." --- catgrad-llm/scripts/compare/talkie_compare.sh | 267 ---------------- catgrad-llm/scripts/convert_talkie.py | 288 ------------------ catgrad-llm/scripts/llm_talkie.py | 150 --------- catgrad-llm/src/models/talkie.rs | 14 +- 4 files changed, 12 insertions(+), 707 deletions(-) delete mode 100755 catgrad-llm/scripts/compare/talkie_compare.sh delete mode 100644 catgrad-llm/scripts/convert_talkie.py delete mode 100644 catgrad-llm/scripts/llm_talkie.py diff --git a/catgrad-llm/scripts/compare/talkie_compare.sh b/catgrad-llm/scripts/compare/talkie_compare.sh deleted file mode 100755 index 5ab9765c..00000000 --- a/catgrad-llm/scripts/compare/talkie_compare.sh +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env bash -# Stability harness for the Talkie port. -# -# Token-by-token diffs the catgrad implementation against the talkie -# package's PyTorch reference. The reference loader holds the model in -# RAM for the whole matrix (one load, all cases) — without that, each -# 13B Python process spikes to ~50 GB and back-to-back runs OOM. -# -# Three outputs per case: -# ref — talkie / pytorch (CPU, dtype-cast bf16 by default) -# cat-k — catgrad-llm with KV cache -# cat-nok — catgrad-llm without KV cache -# -# Comparisons: -# ref vs cat-k — does the port match the upstream reference? -# This is the headline correctness check. -# cat-k vs cat-nok — does catgrad's cache implementation match its -# uncached path? In bf16 this can drift by one -# argmax tie-break and is informative but not -# strictly a talkie-level concern. -# -# Modes: -# * Single-shot — set TALKIE_PROMPT or TALKIE_SEQLEN; runs that one case. -# * Matrix — default. Built-in suite of prompts × lengths. -# -# Required env: -# TALKIE_DIR directory with the converted model (config.json, -# model.safetensors, tokenizer.json, ...) -# TALKIE_VOCAB path to the original vocab.txt -# TALKIE_VENV path to a venv with the `talkie` package installed -# (`pip install git+https://github.com/talkie-lm/talkie`) -# -# Optional env: -# TALKIE_STYLE base | it (default: base) -# TALKIE_PROMPT single-shot override -# TALKIE_SEQLEN single-shot override -# TALKIE_DTYPE bf16 | fp16 | fp32 (default: bf16) -# TALKIE_CKPT fall back to talkie's native ckpt loader (slow, -# transient ~100 GB RAM peak). Default is to load -# bf16 safetensors from $TALKIE_DIR/model.safetensors. -# TALKIE_INCLUDE_NOCACHE=1 -# also run catgrad without KV cache and diff. This -# tests catgrad-internal consistency rather than -# talkie correctness; in bf16 the no-cache path -# drifts by one argmax tie-break for some prompts -# and is much slower (O(N²) per step). Off by default. - -set -euo pipefail - -: "${TALKIE_DIR:?set TALKIE_DIR to the converted model directory}" -: "${TALKIE_VOCAB:?set TALKIE_VOCAB to the original vocab.txt path}" -: "${TALKIE_VENV:?set TALKIE_VENV to a python venv with talkie installed}" - -STYLE="${TALKIE_STYLE:-base}" -DTYPE="${TALKIE_DTYPE:-bf16}" -SAFETENSORS="$TALKIE_DIR/model.safetensors" - -DIR=$(dirname "$0") -SCRIPTS=$(cd "$DIR/.." && pwd) -WORKSPACE=$(cd "$SCRIPTS/../.." 
&& pwd) -LLAMA_BIN="$WORKSPACE/target/release/examples/llama" - -[[ -x "$LLAMA_BIN" ]] || { - echo "build the example first:" >&2 - echo " cargo build --release --features metal --example llama -p catgrad-llm" >&2 - exit 1 -} - -# Build the test-case matrix. Each line is a JSON object emitted to a -# temp file. The Python ref process consumes them as JSONL, the bash -# loop iterates the same list for catgrad runs and diffing. -build_cases() { - local cases_jsonl="$1" - local out_dir="$2" - : > "$cases_jsonl" - for prompt in "${PROMPTS[@]}"; do - for seqlen in "${SEQLENS[@]}"; do - local key - key=$(printf '%s' "${prompt}__${seqlen}" | tr -c 'A-Za-z0-9' '_' | cut -c1-60) - python3 -c "import json,sys;print(json.dumps({'prompt':sys.argv[1],'seq_len':int(sys.argv[2]),'out':sys.argv[3]}))" \ - "$prompt" "$seqlen" "$out_dir/ref__$key" >> "$cases_jsonl" - done - done -} - -run_python_ref_batch() { - local cases_jsonl="$1" loader_args - if [[ -n "${TALKIE_CKPT:-}" ]]; then - loader_args=(--ckpt "$TALKIE_CKPT") - else - loader_args=(--safetensors "$SAFETENSORS") - fi - # shellcheck disable=SC1091 - source "$TALKIE_VENV/bin/activate" - python "$SCRIPTS/llm_talkie.py" \ - "${loader_args[@]}" --vocab "$TALKIE_VOCAB" --style "$STYLE" \ - --dtype "$DTYPE" --batch < "$cases_jsonl" -} - -cat_dtype() { - # Python side uses bf16/fp16/fp32; the llama example uses bf16/f16/f32. - case "$1" in - fp32) echo "f32" ;; - fp16) echo "f16" ;; - *) echo "$1" ;; - esac -} - -run_cat() { - local prompt="$1" seqlen="$2" out="$3" extra_flags="$4" - local cat_dt; cat_dt=$(cat_dtype "$DTYPE") - # shellcheck disable=SC2086 - if ! "$LLAMA_BIN" -m "$TALKIE_DIR" --raw $extra_flags --dtype "$cat_dt" \ - -p "$prompt" -s "$seqlen" > "$out" 2>"$out.err"; then - echo "[cat] llama failed for prompt=$prompt seqlen=$seqlen:" >&2 - sed 's/^/ /' < "$out.err" >&2 - return 1 - fi -} - -# Compare two outputs token-by-token (using the talkie tokenizer so -# token boundaries are real, not whitespace). Prints one of: -# -# ok — full match -# drift TOK/TOTAL CTX — diverged at token TOK out of TOTAL; CTX shows -# the first ref-vs-cat tokens after the split -# -# Used both to call PASS/FAIL and to surface the matched-prefix length, -# which is the actual signal in cross-implementation bf16: 100% prefix -# match is "exact"; partial match quantifies how far before noise flips -# a borderline argmax. 
-diff_summary() { - local a="$1" b="$2" - TALKIE_VOCAB="$TALKIE_VOCAB" TALKIE_STYLE="$STYLE" \ - "$TALKIE_VENV/bin/python" - "$a" "$b" <<'PY' -import os, pathlib, sys -from talkie.tokenizer import build_tokenizer -tk = build_tokenizer(os.environ["TALKIE_VOCAB"], style=os.environ.get("TALKIE_STYLE", "base")) -a = pathlib.Path(sys.argv[1]).read_text().rstrip("\n") -b = pathlib.Path(sys.argv[2]).read_text().rstrip("\n") -ta = tk.encode(a, allowed_special="all") -tb = tk.encode(b, allowed_special="all") -n = min(len(ta), len(tb)) -i = 0 -while i < n and ta[i] == tb[i]: - i += 1 -total = max(len(ta), len(tb)) -if i == len(ta) == len(tb): - print("ok") -else: - a_tail = tk.decode(ta[i : i + 5]) - b_tail = tk.decode(tb[i : i + 5]) - print(f"drift {i}/{total} ref={a_tail!r} cat={b_tail!r}") -PY -} - -run_matrix() { - local out_dir cases_jsonl - out_dir=$(mktemp -d) - cases_jsonl="$out_dir/cases.jsonl" - trap 'rm -rf "$out_dir"' RETURN - - build_cases "$cases_jsonl" "$out_dir" - - echo "Talkie stability matrix — dtype=$DTYPE style=$STYLE" - echo "ref: $([[ -n "${TALKIE_CKPT:-}" ]] && echo "$TALKIE_CKPT" || echo "$SAFETENSORS")" - echo "cat: $TALKIE_DIR" - echo "cases: $(wc -l <"$cases_jsonl") (${#PROMPTS[@]} prompts × ${#SEQLENS[@]} lens)" - echo - - # Phase 1: Python ref. One load, all cases. Slowest part. - echo "[ref] generating all reference outputs (one model load) …" - local t0; t0=$(date +%s) - run_python_ref_batch "$cases_jsonl" - echo "[ref] done in $(( $(date +%s) - t0 ))s" - echo - - # Phase 2: catgrad cat-k for every case. Optional cat-nok if requested - # (it's much slower and is a catgrad-internal consistency check, not a - # talkie-correctness one). - echo "[cat] running catgrad cached$([[ -n "${TALKIE_INCLUDE_NOCACHE:-}" ]] && echo " + uncached") for each case …" - local total_matched=0 total_tokens=0 - while IFS= read -r line; do - local prompt seqlen ref out_k out_nok - prompt=$(python3 -c "import json,sys;print(json.loads(sys.argv[1])['prompt'])" "$line") - seqlen=$(python3 -c "import json,sys;print(json.loads(sys.argv[1])['seq_len'])" "$line") - ref=$(python3 -c "import json,sys;print(json.loads(sys.argv[1])['out'])" "$line") - out_k="${ref/ref__/catk__}" - run_cat "$prompt" "$seqlen" "$out_k" "-k" - - local sum_k sum_nok="" - sum_k=$(diff_summary "$ref" "$out_k") - - if [[ -n "${TALKIE_INCLUDE_NOCACHE:-}" ]]; then - out_nok="${ref/ref__/catn__}" - run_cat "$prompt" "$seqlen" "$out_nok" "" - sum_nok=$(diff_summary "$out_k" "$out_nok") - fi - - # Accumulate token-level stability stats for the headline number. 
- if [[ "$sum_k" = "ok" ]]; then - total_matched=$((total_matched + seqlen)) - total_tokens=$((total_tokens + seqlen)) - else - local matched=${sum_k#drift }; matched=${matched%%/*} - local denom=${sum_k#drift *\/}; denom=${denom%% *} - total_matched=$((total_matched + matched)) - total_tokens=$((total_tokens + denom)) - fi - - local row - row=$(printf "%-50s s=%-3d " "$prompt" "$seqlen") - if [[ "$sum_k" = "ok" && ( -z "$sum_nok" || "$sum_nok" = "ok" ) ]]; then - printf "%sPASS\n" "$row" - elif [[ "$sum_k" = "ok" ]]; then - printf "%sPASS (ref==cat-k); cat-k != cat-nok: %s\n" "$row" "$sum_nok" - DRIFT=$((DRIFT + 1)) - else - printf "%sFAIL (ref != cat-k): %s\n" "$row" "$sum_k" - [[ -n "$sum_nok" && "$sum_nok" != "ok" ]] && printf " cat-k != cat-nok: %s\n" "$sum_nok" - FAILED=$((FAILED + 1)) - fi - done < "$cases_jsonl" - - local pct=0 - [[ "$total_tokens" -gt 0 ]] && pct=$(( total_matched * 1000 / total_tokens )) - echo - if [[ -n "${TALKIE_INCLUDE_NOCACHE:-}" ]]; then - printf "summary: %d cases, %d ref-mismatches, %d cat-k/cat-nok drifts; token match %d/%d (%d.%d%%)\n" \ - "$(wc -l <"$cases_jsonl")" "$FAILED" "$DRIFT" \ - "$total_matched" "$total_tokens" "$((pct / 10))" "$((pct % 10))" - else - printf "summary: %d cases, %d ref-mismatches; token match %d/%d (%d.%d%%)\n" \ - "$(wc -l <"$cases_jsonl")" "$FAILED" \ - "$total_matched" "$total_tokens" "$((pct / 10))" "$((pct % 10))" - fi - return "$FAILED" -} - -# --------------------------------------------------------------------------- -# Single-shot mode -# --------------------------------------------------------------------------- - -if [[ -n "${TALKIE_PROMPT:-}" || -n "${TALKIE_SEQLEN:-}" ]]; then - PROMPTS=("${TALKIE_PROMPT:-Once upon a time}") - SEQLENS=("${TALKIE_SEQLEN:-40}") -else - # --------------------------------------------------------------------------- - # Default matrix: variety of prompts × two lengths. - # --------------------------------------------------------------------------- - PROMPTS=( - "The quick brown fox" - "Once upon a time" - "Category theory is" - "If scientists discover life on other planets," - 'Mr. Carnegie said: "I am' - ) - # Talkie's forward has no KV cache, so the Python ref is O(N²). - # Two lengths balance "do we match at all" against "does drift appear - # over many tokens" — 40 is enough that any sub-ulp bf16 disagreement - # on the top two logits will eventually pick a different argmax. - SEQLENS=(20 40) -fi - -FAILED=0 -DRIFT=0 -run_matrix diff --git a/catgrad-llm/scripts/convert_talkie.py b/catgrad-llm/scripts/convert_talkie.py deleted file mode 100644 index 2062385f..00000000 --- a/catgrad-llm/scripts/convert_talkie.py +++ /dev/null @@ -1,288 +0,0 @@ -"""Convert a Talkie release into a catgrad-loadable HF-style directory. - -Input: a Talkie release as published on HuggingFace - (https://huggingface.co/talkie-lm/talkie-1930-13b-{base,it}): - - final.ckpt / rl-refined.pt / base.ckpt (raw torch.save dict) - - vocab.txt (tiktoken BPE) - -Output: a directory catgrad-llm's loader can consume: - - model.safetensors (bf16 weights, talkie-native key names) - - config.json ({"architectures": ["TalkieForCausalLM"], ...}) - - tokenizer.json (HF tokenizers, byte-level BPE) - - tokenizer_config.json (chat template for IT variants) - -Dependencies: torch, safetensors, tiktoken, tokenizers. 
- -Usage: - python convert_talkie.py \ - --ckpt /path/to/final.ckpt \ - --vocab /path/to/vocab.txt \ - --out /path/to/talkie-1930-13b-base \ - --style base -""" - -from __future__ import annotations - -import argparse -import json -from pathlib import Path - -import torch -from safetensors.torch import save_file -from tiktoken.load import load_tiktoken_bpe -from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers - - -BASE_VOCAB_SIZE = 65536 - -# Matches src/talkie/tokenizer.py — the tiktoken pat_str, joined with `|`. -PAT_STR = "|".join( - [ - r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?", - r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?", - r"\p{N}{1,3}", - r" ?[^\s\p{L}\p{N}]+[\r\n/]*", - r"\s*[\r\n]+", - r"\s+(?!\S)", - r"\s+", - ] -) - -BASE_SPECIAL_TOKENS = {"<|endoftext|>": BASE_VOCAB_SIZE - 1} -IT_SPECIAL_TOKENS = { - "<|endoftext|>": BASE_VOCAB_SIZE - 1, - "<|end|>": BASE_VOCAB_SIZE, - "<|user|>": BASE_VOCAB_SIZE + 1, - "<|assistant|>": BASE_VOCAB_SIZE + 2, - "<|system|>": BASE_VOCAB_SIZE + 3, -} - -CHAT_TEMPLATE = ( - "{%- for message in messages -%}" - "<|{{ message.role }}|>{{ message.content }}<|end|>" - "{%- endfor -%}" - "{%- if add_generation_prompt -%}<|assistant|>{%- endif -%}" -) - -DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} - - -def convert_weights(ckpt_path: Path, out_path: Path, dtype: torch.dtype) -> dict: - """Pickle → safetensors. Preserves Talkie's native param names verbatim - (`embed.weight`, `blocks.{i}.attn.attn_query.weight`, …, `lm_head`, - `lm_head_gain.w_g`); strips `_orig_mod.` from torch.compile. - - Returns the inferred config dict so the caller can write `config.json`. - """ - raw = torch.load(ckpt_path, map_location="cpu", weights_only=False) - if isinstance(raw, dict) and "model_state_dict" in raw: - sd = raw["model_state_dict"] - elif isinstance(raw, dict) and "model" in raw: - sd = raw["model"] - else: - sd = raw - sd = {k.removeprefix("_orig_mod."): v for k, v in sd.items()} - - # Sanity check: the keys we expect. - required = {"embed.weight", "lm_head", "lm_head_gain.w_g"} - missing = required - sd.keys() - if missing: - raise SystemExit(f"checkpoint missing keys: {sorted(missing)}") - - vocab_size, n_embd = sd["embed.weight"].shape - n_layer = 1 + max( - int(k.split(".")[1]) for k in sd if k.startswith("blocks.") - ) - n_head = sd["blocks.0.attn.head_gain.head_g"].shape[0] - head_dim = n_embd // n_head - - # Cast in place to halve peak memory: 13B fp32 + 13B bf16 ≈ 80 GB if held - # together; popping each fp32 tensor as we cast keeps it near 30 GB. - converted = {} - for k in list(sd.keys()): - converted[k] = sd.pop(k).to(dtype).contiguous() - save_file(converted, out_path) - - return { - "vocab_size": int(vocab_size), - "hidden_size": int(n_embd), - "num_hidden_layers": int(n_layer), - "num_attention_heads": int(n_head), - "head_dim": int(head_dim), - } - - -def write_config(out_dir: Path, weights_meta: dict, style: str) -> None: - """Synthesize config.json. Talkie has no native config — everything - is hardcoded in src/talkie/model.py's GPTConfig + the call to - `_precompute_rotary_embeddings(base=1_000_000)`. 
- """ - eos_id = ( - IT_SPECIAL_TOKENS["<|end|>"] if style == "it" - else BASE_SPECIAL_TOKENS["<|endoftext|>"] - ) - config = { - "architectures": ["TalkieForCausalLM"], - "model_type": "talkie", - **weights_meta, - "max_position_embeddings": 2048, - "rope_theta": 1_000_000.0, - "rms_norm_eps": 1e-6, # F.rms_norm in talkie uses dtype-default; small ε is closest in fp32 - "tie_word_embeddings": False, - "torch_dtype": "bfloat16", - "eos_token_id": eos_id, - } - if style == "it": - # IT model expanded the embedding to 65540 during fine-tuning. - config["vocab_size"] = max(config["vocab_size"], BASE_VOCAB_SIZE + 4) - - (out_dir / "config.json").write_text(json.dumps(config, indent=2)) - - -def _bytes_to_unicode() -> dict[int, str]: - """GPT-2 byte→unicode mapping. Reversibly maps every byte to a printable - Unicode codepoint so byte-level BPE can be expressed as a string-keyed - vocab + merges file.""" - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(256): - if b not in bs: - bs.append(b) - cs.append(256 + n) - n += 1 - return dict(zip(bs, [chr(c) for c in cs])) - - -def _bpe_split(token: bytes, ranks: dict[bytes, int]) -> tuple[bytes, bytes]: - """Recover the (left, right) merge that produced `token` by replaying the - BPE algorithm with all ranks strictly less than `ranks[token]`. - Standard recipe — same one transformers' TikTokenConverter uses.""" - parts = [bytes([b]) for b in token] - target = ranks[token] - while True: - min_idx, min_rank = None, None - for i in range(len(parts) - 1): - r = ranks.get(parts[i] + parts[i + 1]) - if r is not None and r < target and (min_rank is None or r < min_rank): - min_idx, min_rank = i, r - if min_idx is None: - break - parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] - if len(parts) != 2: - raise ValueError(f"could not reduce token {token!r} (rank {target}) to a 2-part merge") - return parts[0], parts[1] - - -def write_tokenizer(vocab_path: Path, out_dir: Path, style: str) -> None: - """Convert tiktoken vocab.txt → HF tokenizer.json (byte-level BPE).""" - ranks = load_tiktoken_bpe(str(vocab_path)) - # Talkie drops the highest rank (reserved for <|endoftext|>). 
- ranks = {tok: r for tok, r in ranks.items() if r < BASE_VOCAB_SIZE - 1} - - b2u = _bytes_to_unicode() - - def encode(b: bytes) -> str: - return "".join(b2u[x] for x in b) - - vocab = {encode(tok): r for tok, r in ranks.items()} - merges = [] - for tok, _ in sorted(ranks.items(), key=lambda kv: kv[1]): - if len(tok) == 1: - continue - left, right = _bpe_split(tok, ranks) - merges.append((encode(left), encode(right))) - - tk = Tokenizer(models.BPE(vocab=vocab, merges=merges, fuse_unk=False, byte_fallback=False)) - tk.pre_tokenizer = pre_tokenizers.Sequence([ - pre_tokenizers.Split(pattern=Regex(PAT_STR), behavior="isolated"), - pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False), - ]) - tk.decoder = decoders.ByteLevel() - - specials = IT_SPECIAL_TOKENS if style == "it" else BASE_SPECIAL_TOKENS - tk.add_special_tokens(list(specials.keys())) - - tk.save(str(out_dir / "tokenizer.json")) - - tcfg = { - "model_max_length": 2048, - "bos_token": None, - "eos_token": "<|end|>" if style == "it" else "<|endoftext|>", - "pad_token": None, - "added_tokens_decoder": { - str(idx): {"content": tok, "special": True} - for tok, idx in specials.items() - }, - } - if style == "it": - tcfg["chat_template"] = CHAT_TEMPLATE - (out_dir / "tokenizer_config.json").write_text(json.dumps(tcfg, indent=2)) - - -def verify_tokenizer(vocab_path: Path, out_dir: Path, style: str) -> None: - """Round-trip a sample through both tiktoken and the converted HF tokenizer. - Failures here mean the BPE reconstruction or pre-tokenizer regex diverged — - do not ship without this passing.""" - import tiktoken - - ranks = load_tiktoken_bpe(str(vocab_path)) - ranks = {tok: r for tok, r in ranks.items() if r < BASE_VOCAB_SIZE - 1} - specials = IT_SPECIAL_TOKENS if style == "it" else BASE_SPECIAL_TOKENS - tt = tiktoken.Encoding( - name="talkie", pat_str=PAT_STR, mergeable_ranks=ranks, special_tokens=specials - ) - hf = Tokenizer.from_file(str(out_dir / "tokenizer.json")) - - samples = [ - "Hello, world!", - "The quick brown fox jumps over the lazy dog.", - "It's a fine day in 1930.\nLet us discuss aeronautics.", - " leading spaces and\ttabs\n\nand newlines", - "café naïve résumé — em-dash and ellipsis…", - ] - if style == "it": - samples.append("<|user|>hi<|end|><|assistant|>") - - for s in samples: - a = tt.encode(s, allowed_special=set(specials)) - b = hf.encode(s).ids - if a != b: - raise SystemExit( - f"tokenizer mismatch on {s!r}\n tiktoken: {a}\n hf: {b}" - ) - print(f"tokenizer round-trip ok ({len(samples)} samples)") - - -def main() -> None: - ap = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) - ap.add_argument("--ckpt", type=Path, required=True) - ap.add_argument("--vocab", type=Path, required=True) - ap.add_argument("--out", type=Path, required=True) - ap.add_argument("--style", choices=["base", "it"], required=True) - ap.add_argument("--dtype", choices=DTYPES.keys(), default="bf16") - args = ap.parse_args() - - args.out.mkdir(parents=True, exist_ok=True) - - print("converting weights…") - meta = convert_weights(args.ckpt, args.out / "model.safetensors", DTYPES[args.dtype]) - - print("writing config.json…") - write_config(args.out, meta, args.style) - - print("writing tokenizer.json + tokenizer_config.json…") - write_tokenizer(args.vocab, args.out, args.style) - - print("verifying tokenizer round-trip…") - verify_tokenizer(args.vocab, args.out, args.style) - - print(f"done → {args.out}") - - -if __name__ == "__main__": - main() diff --git a/catgrad-llm/scripts/llm_talkie.py 
b/catgrad-llm/scripts/llm_talkie.py deleted file mode 100644 index c76a1c48..00000000 --- a/catgrad-llm/scripts/llm_talkie.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Greedy-decode reference for Talkie. Mirrors `scripts/llm.py`'s -`do_sample=False` path but uses the talkie package (talkie isn't in -HF transformers, so AutoModelForCausalLM doesn't apply). - -Output is true argmax (no Gumbel sampling, no temperature) so the -result is byte-identical across runs and directly comparable to -catgrad-llm's argmax-greedy decode. - -Two loader paths: - - * `--ckpt` — talkie's native `final.ckpt`. Matches upstream - load semantics, but `torch.load` materialises the - whole 53 GB fp32 dict in RAM before casting to - bf16 (transient peak ≈ 100 GB). - * `--safetensors` — pre-converted bf16 safetensors (output of - `convert_talkie.py`). Peak ≈ 52 GB. - -Both produce bit-identical model state once cast, so use safetensors -for the stability harness. - -Two run modes: - - * single-shot — `--prompt P --seq-len N`, output goes to `--out` or stdout. - * batch — `--batch`, reads JSON-line cases from stdin - (`{"prompt": "...", "seq_len": N, "out": "/path"}`), - loads the model once, runs all cases. The matrix harness - uses this so we don't pay the 50 GB load N times. -""" - -from __future__ import annotations - -import argparse -import json -import sys -from pathlib import Path - -import torch - -from talkie.model import GPTConfig, TalkieModel, resize_model_embeddings -from talkie.tokenizer import IT_VOCAB_SIZE, build_tokenizer - - -DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} - - -def load_from_safetensors( - path: Path, device: torch.device, target_vocab_size: int | None -) -> TalkieModel: - """Build a TalkieModel from our pre-converted bf16 safetensors. - - Skips the ckpt → fp32-dict → cast round-trip in talkie's own loader, - which would otherwise spike to ~100 GB on a 13B model. 
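-
-    The order of operations, summarising the body below:
-
-        sd = load_file("model.safetensors", device="cpu")  # bf16 tensors
-        model.load_state_dict(sd, strict=True)             # talkie-native keys
-        resize_model_embeddings(...)                       # IT vocab growth only
-        model.to(torch.bfloat16).to(device)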
- """ - from safetensors.torch import load_file - - sd = load_file(str(path), device="cpu") - vocab_size = sd["embed.weight"].shape[0] - config = GPTConfig(vocab_size=vocab_size) - - cpu = torch.device("cpu") - model = TalkieModel(config, cpu) - model.load_state_dict(sd, strict=True) - del sd - - if target_vocab_size is not None and vocab_size < target_vocab_size: - model = resize_model_embeddings(model, target_vocab_size, cpu) - - model = model.to(dtype=torch.bfloat16).to(device) - model.device = device - model.eval() - return model - - -def load_model(args, device: torch.device) -> TalkieModel: - target_vocab = IT_VOCAB_SIZE if args.style == "it" else None - if args.safetensors is not None: - model = load_from_safetensors(args.safetensors, device, target_vocab) - else: - from talkie.model import load_checkpoint - model = load_checkpoint(str(args.ckpt), device, target_vocab_size=target_vocab) - if DTYPES[args.dtype] != torch.bfloat16: - model = model.to(dtype=DTYPES[args.dtype]) - model.eval() - return model - - -def generate(model: TalkieModel, tokenizer, prompt: str, seq_len: int) -> str: - ids = tokenizer.encode(prompt, allowed_special="all") - x = torch.tensor([ids], device=model.device, dtype=torch.long) - with torch.no_grad(): - for _ in range(seq_len): - logits = model.forward(x) - next_id = int(torch.argmax(logits, dim=-1).item()) - x = torch.cat( - [x, torch.tensor([[next_id]], device=model.device, dtype=torch.long)], - dim=1, - ) - return tokenizer.decode(x[0].tolist()) - - -def main() -> None: - ap = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) - src = ap.add_mutually_exclusive_group(required=True) - src.add_argument("--ckpt", type=Path, help="talkie native final.ckpt (slow load)") - src.add_argument("--safetensors", type=Path, help="bf16 safetensors (fast load)") - ap.add_argument("--vocab", type=Path, required=True) - ap.add_argument("--style", choices=["base", "it"], default="base") - ap.add_argument("--dtype", choices=DTYPES.keys(), default="bf16") - ap.add_argument("--device", default=None, - help="Defaults to cuda if available, otherwise cpu.") - ap.add_argument("-p", "--prompt", help="single-shot prompt") - ap.add_argument("-s", "--seq-len", type=int, default=20) - ap.add_argument("-o", "--out", type=Path, help="single-shot output file (default stdout)") - ap.add_argument("--batch", action="store_true", - help="Read JSONL cases from stdin (one {prompt, seq_len, out} per line).") - args = ap.parse_args() - - if args.batch and args.prompt is not None: - ap.error("--batch and --prompt are mutually exclusive") - if not args.batch and args.prompt is None: - ap.error("either --batch or --prompt is required") - - device = torch.device( - args.device or ("cuda" if torch.cuda.is_available() else "cpu") - ) - print(f"loading talkie on {device} dtype={args.dtype}…", file=sys.stderr, flush=True) - tokenizer = build_tokenizer(str(args.vocab), style=args.style) - model = load_model(args, device) - print("loaded.", file=sys.stderr, flush=True) - - if args.batch: - for raw in sys.stdin: - line = raw.strip() - if not line: - continue - case = json.loads(line) - text = generate(model, tokenizer, case["prompt"], int(case["seq_len"])) - Path(case["out"]).write_text(text + "\n") - print(f" case {case['prompt']!r} s={case['seq_len']} → {case['out']}", - file=sys.stderr, flush=True) - else: - text = generate(model, tokenizer, args.prompt, args.seq_len) - if args.out: - args.out.write_text(text + "\n") - else: - print(text) - - -if __name__ == "__main__": - main() diff --git 
a/catgrad-llm/src/models/talkie.rs b/catgrad-llm/src/models/talkie.rs index 35b8d7e0..3d78d707 100644 --- a/catgrad-llm/src/models/talkie.rs +++ b/catgrad-llm/src/models/talkie.rs @@ -17,6 +17,12 @@ //! RoPE convention differs from catgrad's default by a sign on `sin` //! (talkie: `y1 = x1·cos + x2·sin`, catgrad: `y1 = x1·cos - x2·sin`); we //! negate `cache.sin` once after init to match. +//! +//! Parameter naming follows the HF-style port at +//! `lewtun/talkie-1930-13b-it-hf` — the decoder stack lives under +//! `model.{embed,blocks.…}` while `lm_head` and `lm_head_gain.w_g` are at +//! the root (matching `TalkieForCausalLM` having `self.model = TalkieModel(…)` +//! and the head as direct attributes). #![allow(clippy::too_many_arguments)] @@ -111,9 +117,13 @@ impl TalkieModel { max_positions: Var, ) -> Vec { let eps = self.config.rms_norm_eps; + // The HF-style port wraps the decoder stack in a `TalkieModel` that + // sits under the `TalkieForCausalLM`'s `self.model`; lm_head and + // lm_head_gain stay at the root. + let m = p.extend(["model"]).unwrap(); // Embed → input RMSNorm (unweighted) → save as e_x for embed-skip. - let x = embeddings(builder, p.extend(["embed"]).unwrap(), x); + let x = embeddings(builder, m.extend(["embed"]).unwrap(), x); let x = rmsnorm_raw::<3>(builder, eps, x); let e_x = x.clone(); @@ -144,7 +154,7 @@ impl TalkieModel { &mut cache, pos.clone(), e_x.clone(), - p.extend(["blocks", &i.to_string()]).unwrap(), + m.extend(["blocks", &i.to_string()]).unwrap(), x, ); }