diff --git a/CLAUDE.md b/CLAUDE.md
index 69f4334..571d305 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -1,52 +1,108 @@
-# CLAUDE.md
+# go-rocm — AMD ROCm GPU Inference
## What This Is
-AMD ROCm GPU inference for Linux. Module: `forge.lthn.ai/core/go-rocm`
+AMD ROCm GPU inference for Linux via a managed `llama-server` subprocess. Module: `dappco.re/go/rocm`.
-Implements `inference.Backend` and `inference.TextModel` (from `core/go-inference`) using llama.cpp compiled with HIP/ROCm. Targets AMD RDNA 3+ GPUs.
+Implements `inference.Backend` and `inference.TextModel` (from `core/go-inference`) using llama.cpp compiled with `-DGGML_HIP=ON`. Targets AMD RDNA 3+ GPUs (tested on Radeon RX 7800 XT, gfx1100).
-## Target Hardware
+Sibling to `go-mlx` (Metal on macOS). Both expose the same interface; users select at runtime based on `Available()`.
-- **GPU**: AMD Radeon RX 7800 XT (gfx1100, RDNA 3, 16 GB VRAM) — confirmed gfx1100, not gfx1101
-- **OS**: Ubuntu 24.04 LTS (linux/amd64)
-- **ROCm**: 7.2.0 installed
-- **Kernel**: 6.17.0
+## Key Facts
-## Commands
+- **Subprocess model:** llama-server runs as an isolated process, communicates via HTTP/SSE
+- **GGUF parser:** Reads model metadata (v2/v3) without loading tensors — enables fast discovery
+- **VRAM monitoring:** sysfs-based (no ROCm runtime library dependency)
+- **iGPU masking:** `HIP_VISIBLE_DEVICES=0` hardcoded — Ryzen 9 iGPU crashes llama-server if exposed
+- **Auto-register:** `init()` registers the backend via `inference.Register()` on linux && amd64
+- **Platform stubs:** Exports no-op funcs on non-Linux/amd64 to avoid build failures
+- **Error wrapping:** All errors use `coreerr.E(scope, msg, cause)` from `go-log`
-```bash
-go test ./... # Unit tests (no GPU required)
-go test -tags rocm ./... # Integration tests + benchmarks (GPU required)
-go test -tags rocm -v -run TestROCm ./... # Full GPU tests only
-go test -tags rocm -bench=. -benchtime=3x ./...
# Benchmarks -``` +## Hardware & OS -## Architecture +| Component | Value | +|-----------|-------| +| GPU | Radeon RX 7800 XT (gfx1100, RDNA 3, 16 GB) | +| CPU | Ryzen 9 9950X | +| OS | Ubuntu 24.04 LTS | +| ROCm | 7.2.0 | +| Kernel | 6.17.0 | -See `docs/architecture.md` for full detail. +## Architecture ``` -go-rocm/ -├── backend.go inference.Backend (linux && amd64) -├── model.go inference.TextModel (linux && amd64) -├── server.go llama-server subprocess lifecycle -├── vram.go VRAM monitoring via sysfs -├── discover.go GGUF model discovery -├── register_rocm.go auto-registers via init() (linux && amd64) -├── rocm_stub.go stubs for non-linux/non-amd64 -└── internal/ - ├── llamacpp/ llama-server HTTP client + health check - └── gguf/ GGUF v2/v3 binary metadata parser +dappco.re/go/rocm/ +├── Public: +│ ├── rocm.go [VRAMInfo, ModelInfo types] +│ ├── discover.go [DiscoverModels(dir) -> []ModelInfo] +│ ├── register_rocm.go [init() register] +│ +├── Backend/Model (linux && amd64): +│ ├── backend.go [rocmBackend impl] +│ ├── model.go [rocmModel impl, metrics, streaming] +│ ├── server.go [subprocess lifecycle, port mgmt] +│ ├── vram.go [GetVRAMInfo() via sysfs] +│ ├── rocm_stub.go [stubs for other platforms] +│ +└── Internal: + ├── internal/gguf/ + │ └── gguf.go [GGUF v2/v3 binary header parser] + │ + └── internal/llamacpp/ + ├── client.go [HTTP client, Complete, ChatComplete] + └── health.go [/health endpoint polling] ``` -## Critical: iGPU Crash +## Critical Rules + +1. **iGPU always masked:** `serverEnv()` enforces `HIP_VISIBLE_DEVICES=0`. This is non-negotiable. Do not accept as config or env var override. + +2. **Platform-specific:** Build tags `linux && amd64` for GPU code. Stubs on other platforms prevent build errors. + +3. **Subprocess isolation:** llama-server is not trusted. Runs at default perms, minimal env, auto-killed on exit. -The Ryzen 9 9950X iGPU appears as ROCm Device 1. llama-server crashes trying to split tensors across it. 
`serverEnv()` always sets `HIP_VISIBLE_DEVICES=0`. Do not remove or weaken this.
+4. **Error scope:** All errors use `coreerr.E()`. No `fmt.Errorf`, no `errors.New`, no `log` package.
-## Building llama-server with ROCm
+5. **Banned imports:** `fmt`, `log`, `errors`, `os/exec` — use their core.* equivalents. (Note: `os` is used directly for file/env ops, justified by GPU module weight constraints.)
+
+6. **Metrics best-effort:** VRAM stats are read non-atomically from sysfs; under heavy churn, transient gaps are expected. Recording is not real-time.
+
+## Spec Index
+
+See `/sessions/vibrant-sharp-fermat/mnt/plans/code/core/go/rocm/RFC.md`:
+
+- **§1–2:** Overview & package layout
+- **§3:** Type definitions (VRAMInfo, ModelInfo, rocmBackend, rocmModel, server)
+- **§4:** Inference pipeline (Load, Generate, Chat, metrics)
+- **§5:** GGUF parser internals
+- **§6:** llama-server HTTP bridge
+- **§7–9:** VRAM discovery, model discovery, platform support
+- **§10–16:** Error handling, config, quantisation, design notes, cross-refs
+
+## Working Commands
```bash
+# Unit tests (no GPU required)
+go test ./...
+
+# Integration tests + benchmarks (GPU required, gfx1100)
+go test -tags rocm ./...
+
+# Full GPU tests only
+go test -tags rocm -v -run TestROCm ./...
+
+# Benchmarks
+go test -tags rocm -bench=. -benchtime=3x ./...
+
+# Format & lint
+go fmt ./...
+``` + +## Building llama-server + +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp cmake -B build \ -DGGML_HIP=ON \ -DAMDGPU_TARGETS=gfx1100 \ @@ -56,34 +112,26 @@ cmake --build build --parallel $(nproc) -t llama-server sudo cp build/bin/llama-server /usr/local/bin/llama-server ``` -## Environment Variables +## Coordination -| Variable | Default | Purpose | -|----------|---------|---------| -| `ROCM_LLAMA_SERVER_PATH` | PATH lookup | Path to llama-server binary | -| `HIP_VISIBLE_DEVICES` | overridden to `0` | Always forced to 0 — do not rely on ambient value | +- **Virgil** (forge.lthn.ai/core) — orchestrator, task writer, PR reviewer +- **go-mlx** — sibling Metal backend (same interface contract) +- **go-inference** — shared TextModel/Backend interface definitions +- **go-ml** — scoring engine wrapping both backends +- **LEM training** — uses go-rocm for model eval on Charon homelab -## Coding Standards +## Test Naming -- UK English -- Tests: testify assert/require -- Build tags: `linux && amd64` for GPU code, `rocm` for integration tests -- Errors: `coreerr.E("pkg.Func", "what failed", err)` via `go-log`, never `fmt.Errorf` or `errors.New` -- File I/O: `os` package used directly — `go-io` not imported (its transitive deps are too heavy for a GPU inference module) -- Conventional commits -- Co-Author: `Co-Authored-By: Virgil ` -- Licence: EUPL-1.2 +Format: `TestFilename_Function_{Good,Bad,Ugly}` — all three categories mandatory. -## Coordination +Example: `TestModel_Generate_Good`, `TestModel_Generate_Bad`, `TestModel_Generate_Ugly`. 
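The naming rule above is mechanical enough to check with a short script. A minimal sketch follows; `validTestName` and its pattern are illustrative helpers, not part of the go-rocm repository:

```go
package main

import (
	"fmt"
	"regexp"
)

// namePattern encodes the TestFilename_Function_{Good,Bad,Ugly} convention.
// Hypothetical helper for illustration only — not part of go-rocm.
var namePattern = regexp.MustCompile(`^Test[A-Za-z0-9_]+_(Good|Bad|Ugly)$`)

// validTestName reports whether a test name follows the convention.
func validTestName(name string) bool {
	return namePattern.MatchString(name)
}

func main() {
	fmt.Println(validTestName("TestModel_Generate_Good")) // true
	fmt.Println(validTestName("TestModelGenerate"))       // false
}
```

Such a check could be wired into CI, but the repository itself relies on review to enforce the convention.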
-- **Virgil** (core/go) is the orchestrator — writes tasks and reviews PRs -- **go-mlx** is the sibling — Metal backend on macOS, same interface contract -- **go-inference** defines the shared TextModel/Backend interfaces both backends implement -- **go-ml** wraps both backends into the scoring engine +## Commit Style -## Documentation +``` +type(scope): description + +Co-Authored-By: Virgil +``` -- `docs/architecture.md` — component design, data flow, interface contracts -- `docs/development.md` — prerequisites, test commands, benchmarks, coding standards -- `docs/history.md` — completed phases, commit hashes, known limitations -- `docs/plans/` — phase design documents (read-only reference) +Example: `feat(rocm): add VRAM monitoring via sysfs` diff --git a/README.md b/README.md index 912c6f0..6f51f87 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ AMD ROCm GPU inference for Linux via a managed llama-server subprocess. Implements the `inference.Backend` and `inference.TextModel` interfaces from go-inference for AMD RDNA 3+ GPUs (validated on RX 7800 XT with ROCm 7.2). Uses llama-server's OpenAI-compatible streaming API rather than direct HIP CGO bindings, giving access to 50+ GGUF model architectures with GPU crash isolation. Includes a GGUF v2/v3 binary metadata parser, sysfs VRAM monitoring, and model discovery. Platform-restricted: `linux/amd64` only; a safe stub compiles everywhere else. -**Module**: `forge.lthn.ai/core/go-rocm` +**Module**: `dappco.re/go/rocm` **Licence**: EUPL-1.2 **Language**: Go 1.25 @@ -11,7 +11,7 @@ AMD ROCm GPU inference for Linux via a managed llama-server subprocess. 
Implemen ```go import ( "forge.lthn.ai/core/go-inference" - _ "forge.lthn.ai/core/go-rocm" // registers "rocm" backend via init() + _ "dappco.re/go/rocm" // registers "rocm" backend via init() ) // Requires llama-server compiled with HIP/ROCm on PATH diff --git a/backend.go b/backend.go index 8f1bdcf..d08e700 100644 --- a/backend.go +++ b/backend.go @@ -6,14 +6,16 @@ import ( "os" "strings" - coreerr "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-inference" - "forge.lthn.ai/core/go-rocm/internal/gguf" + "dappco.re/go/inference" + coreerr "dappco.re/go/log" + "dappco.re/go/rocm/internal/gguf" ) // rocmBackend implements inference.Backend for AMD ROCm GPUs. type rocmBackend struct{} +const defaultContextLengthCap = 4096 + func (b *rocmBackend) Name() string { return "rocm" } // Available reports whether ROCm GPU inference can run on this machine. @@ -30,68 +32,83 @@ func (b *rocmBackend) Available() bool { // LoadModel loads a GGUF model onto the AMD GPU via llama-server. // Model architecture is read from GGUF metadata (replacing filename-based guessing). -// If no context length is specified, defaults to min(model_context_length, 4096) -// to prevent VRAM exhaustion on models with 128K+ native context. +// If no context length is specified, defaults to min(model_context_length, +// 4096). When metadata omits the native context, it falls back to 4096 to +// keep the load path on the safe side of VRAM usage. 
func (b *rocmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) { - cfg := inference.ApplyLoadOpts(opts) + loadConfig := inference.ApplyLoadOpts(opts) binary, err := findLlamaServer() if err != nil { return nil, err } - meta, err := gguf.ReadMetadata(path) + metadata, err := gguf.ReadMetadata(path) if err != nil { return nil, coreerr.E("rocm.LoadModel", "read model metadata", err) } - ctxLen := cfg.ContextLen - if ctxLen == 0 && meta.ContextLength > 0 { - ctxLen = int(min(meta.ContextLength, 4096)) - } + contextLength := resolveContextLength(loadConfig.ContextLen, metadata) - srv, err := startServer(binary, path, cfg.GPULayers, ctxLen, cfg.ParallelSlots) + modelServer, err := startServer(serverStartConfig{ + BinaryPath: binary, + ModelPath: path, + GPULayerCount: loadConfig.GPULayers, + ContextSize: contextLength, + ParallelSlotCount: loadConfig.ParallelSlots, + }) if err != nil { return nil, err } - // Map quantisation file type to bit width. - quantBits := 0 - quantGroup := 0 - ftName := gguf.FileTypeName(meta.FileType) - switch { - case strings.HasPrefix(ftName, "Q4_"): - quantBits = 4 - quantGroup = 32 - case strings.HasPrefix(ftName, "Q5_"): - quantBits = 5 - quantGroup = 32 - case strings.HasPrefix(ftName, "Q8_"): - quantBits = 8 - quantGroup = 32 - case strings.HasPrefix(ftName, "Q2_"): - quantBits = 2 - quantGroup = 16 - case strings.HasPrefix(ftName, "Q3_"): - quantBits = 3 - quantGroup = 32 - case strings.HasPrefix(ftName, "Q6_"): - quantBits = 6 - quantGroup = 64 - case ftName == "F16": - quantBits = 16 - case ftName == "F32": - quantBits = 32 - } - return &rocmModel{ - srv: srv, - modelType: meta.Architecture, - modelInfo: inference.ModelInfo{ - Architecture: meta.Architecture, - NumLayers: int(meta.BlockCount), - QuantBits: quantBits, - QuantGroup: quantGroup, - }, + server: modelServer, + modelType: metadata.Architecture, + modelInfo: modelInfoFromMetadata(metadata), }, nil } + +func 
resolveContextLength(requestedContextLength int, metadata gguf.Metadata) int { + if requestedContextLength > 0 { + return requestedContextLength + } + if metadata.ContextLength == 0 { + return defaultContextLengthCap + } + return min(int(metadata.ContextLength), defaultContextLengthCap) +} + +func modelInfoFromMetadata(metadata gguf.Metadata) inference.ModelInfo { + quantBits, quantGroup := quantisationFromFileType(metadata.FileType) + return inference.ModelInfo{ + Architecture: metadata.Architecture, + NumLayers: int(metadata.BlockCount), + QuantBits: quantBits, + QuantGroup: quantGroup, + } +} + +func quantisationFromFileType(fileType uint32) (bits, groupSize int) { + fileTypeName := gguf.FileTypeName(fileType) + + switch { + case strings.HasPrefix(fileTypeName, "Q4_"): + return 4, 32 + case strings.HasPrefix(fileTypeName, "Q5_"): + return 5, 32 + case strings.HasPrefix(fileTypeName, "Q8_"): + return 8, 32 + case strings.HasPrefix(fileTypeName, "Q2_"): + return 2, 16 + case strings.HasPrefix(fileTypeName, "Q3_"): + return 3, 32 + case strings.HasPrefix(fileTypeName, "Q6_"): + return 6, 64 + case fileTypeName == "F16": + return 16, 0 + case fileTypeName == "F32": + return 32, 0 + default: + return 0, 0 + } +} diff --git a/backend_ax7_test.go b/backend_ax7_test.go new file mode 100644 index 0000000..7c62aaf --- /dev/null +++ b/backend_ax7_test.go @@ -0,0 +1,173 @@ +//go:build linux && amd64 + +package rocm + +import ( + "net/http" + "os" + "path/filepath" + "testing" + + "dappco.re/go/inference" +) + +func TestMain(m *testing.M) { + if os.Getenv("ROCM_FAKE_LLAMA_SERVER") == "1" { + runFakeLlamaServer() + return + } + os.Exit(m.Run()) +} + +func runFakeLlamaServer() { + port := "" + for i, arg := range os.Args { + if arg == "--port" && i+1 < len(os.Args) { + port = os.Args[i+1] + } + } + if port == "" { + os.Exit(2) + } + + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", 
"application/json") + _, _ = w.Write([]byte(`{"status":"ok"}`)) + }) + mux.HandleFunc("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { + writeSSEEvent(w, `{"choices":[{"text":"ok","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + }) + mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + writeSSEEvent(w, `{"choices":[{"delta":{"content":"ok"},"finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + }) + + if err := http.ListenAndServe("127.0.0.1:"+port, mux); err != nil { + os.Exit(3) + } +} + +func useFakeLlamaServer(t *testing.T) { + t.Helper() + t.Setenv("ROCM_FAKE_LLAMA_SERVER", "1") + t.Setenv("ROCM_LLAMA_SERVER_PATH", os.Args[0]) +} + +func TestBackend_Backend_Name_Good(t *testing.T) { + name := (&rocmBackend{}).Name() + if name != "rocm" { + t.Fatalf("Name() = %q, want rocm", name) + } + if len(name) == 0 { + t.Fatal("Name() returned an empty backend name") + } +} + +func TestBackend_Backend_Name_Bad(t *testing.T) { + var backend *rocmBackend + name := backend.Name() + if name != "rocm" { + t.Fatalf("nil receiver Name() = %q, want rocm", name) + } + if name == "cpu" { + t.Fatal("Name() returned the wrong backend family") + } +} + +func TestBackend_Backend_Name_Ugly(t *testing.T) { + backend := &rocmBackend{} + first := backend.Name() + second := backend.Name() + if first != second { + t.Fatalf("Name() changed from %q to %q", first, second) + } +} + +func TestBackend_Backend_Available_Good(t *testing.T) { + backend := &rocmBackend{} + available := backend.Available() + if available { + if _, err := os.Stat("/dev/kfd"); err != nil { + t.Fatalf("Available() = true but /dev/kfd stat failed: %v", err) + } + } + if available != backend.Available() { + t.Fatal("Available() changed between adjacent calls") + } +} + +func TestBackend_Backend_Available_Bad(t *testing.T) { + t.Setenv("ROCM_LLAMA_SERVER_PATH", filepath.Join(t.TempDir(), "missing-llama-server")) + available := (&rocmBackend{}).Available() + if 
available { + t.Fatal("Available() = true with missing llama-server override") + } + if ROCmAvailable() != true { + t.Fatal("ROCmAvailable() should still report compiled support") + } +} + +func TestBackend_Backend_Available_Ugly(t *testing.T) { + dir := t.TempDir() + t.Setenv("ROCM_LLAMA_SERVER_PATH", dir) + available := (&rocmBackend{}).Available() + if available { + t.Fatal("Available() = true with directory llama-server override") + } + if _, err := validateLlamaServerPath(dir); err == nil { + t.Fatal("validateLlamaServerPath(directory) error = nil, want error") + } +} + +func TestBackend_Backend_LoadModel_Good(t *testing.T) { + useFakeLlamaServer(t) + dir := t.TempDir() + modelPath := writeDiscoverTestGGUF(t, dir, "model.gguf", [][2]any{ + {"general.architecture", "llama"}, + {"general.name", "AX Model"}, + {"general.file_type", uint32(15)}, + {"llama.context_length", uint32(4096)}, + {"llama.block_count", uint32(32)}, + }) + + model, err := (&rocmBackend{}).LoadModel(modelPath, inference.WithContextLen(1024)) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } + defer model.Close() + if model.ModelType() != "llama" { + t.Fatalf("ModelType() = %q, want llama", model.ModelType()) + } + if model.Info().NumLayers != 32 { + t.Fatalf("Info().NumLayers = %d, want 32", model.Info().NumLayers) + } +} + +func TestBackend_Backend_LoadModel_Bad(t *testing.T) { + t.Setenv("ROCM_LLAMA_SERVER_PATH", filepath.Join(t.TempDir(), "missing-llama-server")) + model, err := (&rocmBackend{}).LoadModel("/missing/model.gguf") + if err == nil { + t.Fatal("LoadModel() error = nil, want missing llama-server error") + } + if model != nil { + t.Fatalf("model = %v, want nil", model) + } +} + +func TestBackend_Backend_LoadModel_Ugly(t *testing.T) { + useFakeLlamaServer(t) + path := filepath.Join(t.TempDir(), "corrupt.gguf") + if err := os.WriteFile(path, []byte("not gguf"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + model, err := (&rocmBackend{}).LoadModel(path) + if 
err == nil { + t.Fatal("LoadModel() error = nil, want metadata error") + } + if model != nil { + t.Fatalf("model = %v, want nil", model) + } +} diff --git a/backend_test.go b/backend_test.go new file mode 100644 index 0000000..e3ebfab --- /dev/null +++ b/backend_test.go @@ -0,0 +1,78 @@ +//go:build linux && amd64 + +package rocm + +import ( + "testing" + + "dappco.re/go/rocm/internal/gguf" +) + +func TestBackend_ResolveContextLength_Good(t *testing.T) { + if got := resolveContextLength(2048, gguf.Metadata{ContextLength: 32768}); got != 2048 { + t.Errorf("resolveContextLength(2048, 32768) = %d, want 2048", got) + } + if got := resolveContextLength(0, gguf.Metadata{ContextLength: 1024}); got != 1024 { + t.Errorf("resolveContextLength(0, 1024) = %d, want 1024", got) + } + if got := resolveContextLength(0, gguf.Metadata{ContextLength: 131072}); got != defaultContextLengthCap { + t.Errorf("resolveContextLength(0, 131072) = %d, want %d", got, defaultContextLengthCap) + } +} + +func TestBackend_ResolveContextLength_Ugly(t *testing.T) { + if got := resolveContextLength(0, gguf.Metadata{}); got != defaultContextLengthCap { + t.Errorf("resolveContextLength(0, empty) = %d, want %d", got, defaultContextLengthCap) + } +} + +func TestBackend_QuantisationFromFileType_Good(t *testing.T) { + testCases := []struct { + name string + fileType uint32 + expectedBits int + expectedGroupSize int + }{ + {name: "q4", fileType: 15, expectedBits: 4, expectedGroupSize: 32}, + {name: "q5", fileType: 17, expectedBits: 5, expectedGroupSize: 32}, + {name: "q8", fileType: 7, expectedBits: 8, expectedGroupSize: 32}, + {name: "q2", fileType: 10, expectedBits: 2, expectedGroupSize: 16}, + {name: "q6", fileType: 18, expectedBits: 6, expectedGroupSize: 64}, + {name: "f16", fileType: 1, expectedBits: 16, expectedGroupSize: 0}, + {name: "f32", fileType: 0, expectedBits: 32, expectedGroupSize: 0}, + {name: "unknown", fileType: 999, expectedBits: 0, expectedGroupSize: 0}, + } + + for _, testCase := range 
testCases {
+		t.Run(testCase.name, func(t *testing.T) {
+			bits, groupSize := quantisationFromFileType(testCase.fileType)
+			if bits != testCase.expectedBits {
+				t.Errorf("bits = %d, want %d", bits, testCase.expectedBits)
+			}
+			if groupSize != testCase.expectedGroupSize {
+				t.Errorf("groupSize = %d, want %d", groupSize, testCase.expectedGroupSize)
+			}
+		})
+	}
+}
+
+func TestBackend_ModelInfoFromMetadata_Good(t *testing.T) {
+	modelInfo := modelInfoFromMetadata(gguf.Metadata{
+		Architecture: "gemma3",
+		BlockCount:   34,
+		FileType:     15,
+	})
+
+	if modelInfo.Architecture != "gemma3" {
+		t.Errorf("Architecture = %q, want %q", modelInfo.Architecture, "gemma3")
+	}
+	if modelInfo.NumLayers != 34 {
+		t.Errorf("NumLayers = %d, want 34", modelInfo.NumLayers)
+	}
+	if modelInfo.QuantBits != 4 {
+		t.Errorf("QuantBits = %d, want 4", modelInfo.QuantBits)
+	}
+	if modelInfo.QuantGroup != 32 {
+		t.Errorf("QuantGroup = %d, want 32", modelInfo.QuantGroup)
+	}
+}
diff --git a/discover.go b/discover.go
index 0d298ca..326e3f6 100644
--- a/discover.go
+++ b/discover.go
@@ -3,15 +3,24 @@ package rocm
 import (
 	"path/filepath"
-	"forge.lthn.ai/core/go-rocm/internal/gguf"
+	coreerr "dappco.re/go/log"
+	"dappco.re/go/rocm/internal/gguf"
 )
-// DiscoverModels scans a directory for GGUF model files and returns
-// structured information about each. Files that cannot be parsed are skipped.
+// DiscoverModels scans a directory for GGUF model files and returns structured
+// information about each. Files that cannot be parsed are skipped.
+//
+//	models, err := DiscoverModels("/data/lem/gguf")
+//	fmt.Println(models[0].Architecture, models[0].Quantisation)
func DiscoverModels(dir string) ([]ModelInfo, error) { - matches, err := filepath.Glob(filepath.Join(dir, "*.gguf")) + root, err := filepath.Abs(dir) if err != nil { - return nil, err + return nil, coreerr.E("rocm.DiscoverModels", "resolve model directory", err) + } + + matches, err := filepath.Glob(filepath.Join(root, "*.gguf")) + if err != nil { + return nil, coreerr.E("rocm.DiscoverModels", "glob gguf files", err) } var models []ModelInfo diff --git a/discover_ax7_test.go b/discover_ax7_test.go new file mode 100644 index 0000000..9ad54de --- /dev/null +++ b/discover_ax7_test.go @@ -0,0 +1,62 @@ +package rocm + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestDiscover_DiscoverModels_Good(t *testing.T) { + dir := t.TempDir() + modelPath := writeDiscoverTestGGUF(t, dir, "agent.gguf", [][2]any{ + {"general.architecture", "llama"}, + {"general.name", "Agent Model"}, + {"general.file_type", uint32(15)}, + {"general.size_label", "8B"}, + {"llama.context_length", uint32(8192)}, + }) + + models, err := DiscoverModels(dir) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) != 1 { + t.Fatalf("len(models) = %d, want 1", len(models)) + } + if models[0].Path != modelPath { + t.Errorf("models[0].Path = %q, want %q", models[0].Path, modelPath) + } + if models[0].Architecture != "llama" { + t.Errorf("models[0].Architecture = %q, want llama", models[0].Architecture) + } +} + +func TestDiscover_DiscoverModels_Bad(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not a gguf header"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + models, err := DiscoverModels(dir) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) != 0 { + t.Errorf("len(models) = %d, want 0", len(models)) + } +} + +func TestDiscover_DiscoverModels_Ugly(t *testing.T) { + dir := filepath.Join(t.TempDir(), "bad[") + models, err := DiscoverModels(dir) + if err == nil { + 
t.Fatal("expected glob error, got nil") + } + if models != nil { + t.Errorf("models = %v, want nil", models) + } + if !strings.Contains(err.Error(), "glob gguf files") { + t.Errorf("err = %v, want contains glob gguf files", err) + } +} diff --git a/discover_test.go b/discover_test.go index 9a6ce1a..e3df239 100644 --- a/discover_test.go +++ b/discover_test.go @@ -4,10 +4,8 @@ import ( "encoding/binary" "os" "path/filepath" + "strings" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // writeDiscoverTestGGUF creates a minimal GGUF v3 file in dir with the given @@ -18,17 +16,27 @@ func writeDiscoverTestGGUF(t *testing.T, dir, filename string, kvs [][2]any) str path := filepath.Join(dir, filename) f, err := os.Create(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("os.Create: %v", err) + } defer f.Close() // Magic: "GGUF" in little-endian - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } // Version 3 - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(3))) + if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { + t.Fatalf("write version: %v", err) + } // Tensor count (uint64): 0 - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(0))) + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatalf("write tensor count: %v", err) + } // KV count (uint64) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(kvs)))) + if err := binary.Write(f, binary.LittleEndian, uint64(len(kvs))); err != nil { + t.Fatalf("write kv count: %v", err) + } for _, kv := range kvs { key := kv[0].(string) @@ -42,22 +50,34 @@ func writeDiscoverKV(t *testing.T, f *os.File, key string, val any) { t.Helper() // Key: uint64 length + bytes - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(key)))) - 
_, err := f.Write([]byte(key)) - require.NoError(t, err) + if err := binary.Write(f, binary.LittleEndian, uint64(len(key))); err != nil { + t.Fatalf("write key len: %v", err) + } + if _, err := f.Write([]byte(key)); err != nil { + t.Fatalf("write key bytes: %v", err) + } switch v := val.(type) { case string: // Type: 8 (string) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(8))) + if err := binary.Write(f, binary.LittleEndian, uint32(8)); err != nil { + t.Fatalf("write string type: %v", err) + } // String value: uint64 length + bytes - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(v)))) - _, err := f.Write([]byte(v)) - require.NoError(t, err) + if err := binary.Write(f, binary.LittleEndian, uint64(len(v))); err != nil { + t.Fatalf("write string len: %v", err) + } + if _, err := f.Write([]byte(v)); err != nil { + t.Fatalf("write string bytes: %v", err) + } case uint32: // Type: 4 (uint32) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(4))) - require.NoError(t, binary.Write(f, binary.LittleEndian, v)) + if err := binary.Write(f, binary.LittleEndian, uint32(4)); err != nil { + t.Fatalf("write uint32 type: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, v); err != nil { + t.Fatalf("write uint32 val: %v", err) + } default: t.Fatalf("writeDiscoverKV: unsupported value type %T", val) } @@ -86,46 +106,146 @@ func TestDiscoverModels(t *testing.T) { }) // Create a non-GGUF file that should be ignored (no .gguf extension). 
- require.NoError(t, os.WriteFile(filepath.Join(dir, "README.txt"), []byte("not a model"), 0644)) + if err := os.WriteFile(filepath.Join(dir, "README.txt"), []byte("not a model"), 0644); err != nil { + t.Fatalf("WriteFile README: %v", err) + } models, err := DiscoverModels(dir) - require.NoError(t, err) - require.Len(t, models, 2) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) != 2 { + t.Fatalf("len(models) = %d, want 2", len(models)) + } // Sort order from Glob is lexicographic, so gemma3 comes first. gemma := models[0] - assert.Equal(t, filepath.Join(dir, "gemma3-4b-q4km.gguf"), gemma.Path) - assert.Equal(t, "gemma3", gemma.Architecture) - assert.Equal(t, "Gemma 3 4B Instruct", gemma.Name) - assert.Equal(t, "Q4_K_M", gemma.Quantisation) - assert.Equal(t, "4B", gemma.Parameters) - assert.Equal(t, uint32(32768), gemma.ContextLen) - assert.Greater(t, gemma.FileSize, int64(0)) + if got, want := gemma.Path, filepath.Join(dir, "gemma3-4b-q4km.gguf"); got != want { + t.Errorf("gemma.Path = %q, want %q", got, want) + } + if gemma.Architecture != "gemma3" { + t.Errorf("gemma.Architecture = %q, want %q", gemma.Architecture, "gemma3") + } + if gemma.Name != "Gemma 3 4B Instruct" { + t.Errorf("gemma.Name = %q, want %q", gemma.Name, "Gemma 3 4B Instruct") + } + if gemma.Quantisation != "Q4_K_M" { + t.Errorf("gemma.Quantisation = %q, want %q", gemma.Quantisation, "Q4_K_M") + } + if gemma.Parameters != "4B" { + t.Errorf("gemma.Parameters = %q, want %q", gemma.Parameters, "4B") + } + if gemma.ContextLen != uint32(32768) { + t.Errorf("gemma.ContextLen = %d, want 32768", gemma.ContextLen) + } + if gemma.FileSize <= 0 { + t.Errorf("gemma.FileSize = %d, want > 0", gemma.FileSize) + } llama := models[1] - assert.Equal(t, filepath.Join(dir, "llama-3.1-8b-q4km.gguf"), llama.Path) - assert.Equal(t, "llama", llama.Architecture) - assert.Equal(t, "Llama 3.1 8B Instruct", llama.Name) - assert.Equal(t, "Q4_K_M", llama.Quantisation) - assert.Equal(t, "8B", 
llama.Parameters) - assert.Equal(t, uint32(131072), llama.ContextLen) - assert.Greater(t, llama.FileSize, int64(0)) + if got, want := llama.Path, filepath.Join(dir, "llama-3.1-8b-q4km.gguf"); got != want { + t.Errorf("llama.Path = %q, want %q", got, want) + } + if llama.Architecture != "llama" { + t.Errorf("llama.Architecture = %q, want %q", llama.Architecture, "llama") + } + if llama.Name != "Llama 3.1 8B Instruct" { + t.Errorf("llama.Name = %q, want %q", llama.Name, "Llama 3.1 8B Instruct") + } + if llama.Quantisation != "Q4_K_M" { + t.Errorf("llama.Quantisation = %q, want %q", llama.Quantisation, "Q4_K_M") + } + if llama.Parameters != "8B" { + t.Errorf("llama.Parameters = %q, want %q", llama.Parameters, "8B") + } + if llama.ContextLen != uint32(131072) { + t.Errorf("llama.ContextLen = %d, want 131072", llama.ContextLen) + } + if llama.FileSize <= 0 { + t.Errorf("llama.FileSize = %d, want > 0", llama.FileSize) + } +} + +func TestDiscoverModels_RelativeDirReturnsAbsolutePaths(t *testing.T) { + parent := t.TempDir() + dir := filepath.Join(parent, "models") + if err := os.Mkdir(dir, 0755); err != nil { + t.Fatalf("Mkdir: %v", err) + } + + path := writeDiscoverTestGGUF(t, dir, "model.gguf", [][2]any{ + {"general.architecture", "llama"}, + {"general.name", "Relative Model"}, + {"general.file_type", uint32(15)}, + }) + + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(parent); err != nil { + t.Fatalf("Chdir parent: %v", err) + } + t.Cleanup(func() { + if err := os.Chdir(wd); err != nil { + t.Fatalf("Chdir restore: %v", err) + } + }) + + models, err := DiscoverModels("models") + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) != 1 { + t.Fatalf("len(models) = %d, want 1", len(models)) + } + gotPath, err := filepath.EvalSymlinks(models[0].Path) + if err != nil { + t.Fatalf("EvalSymlinks got path: %v", err) + } + wantPath, err := filepath.EvalSymlinks(path) + if err != nil { + t.Fatalf("EvalSymlinks 
want path: %v", err) + } + if gotPath != wantPath { + t.Errorf("models[0].Path = %q, want %q", gotPath, wantPath) + } } func TestDiscoverModels_EmptyDir(t *testing.T) { dir := t.TempDir() models, err := DiscoverModels(dir) - require.NoError(t, err) - assert.Empty(t, models) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) != 0 { + t.Errorf("len(models) = %d, want 0", len(models)) + } } func TestDiscoverModels_NotFound(t *testing.T) { // filepath.Glob returns nil, nil for a pattern matching no files, // even when the directory does not exist. models, err := DiscoverModels("/nonexistent/dir") - require.NoError(t, err) - assert.Empty(t, models) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) != 0 { + t.Errorf("len(models) = %d, want 0", len(models)) + } +} + +func TestDiscoverModels_BadPattern(t *testing.T) { + dir := filepath.Join(t.TempDir(), "bad[") + + _, err := DiscoverModels(dir) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "glob gguf files") { + t.Errorf("err = %v, want contains %q", err, "glob gguf files") + } } func TestDiscoverModels_SkipsCorruptFile(t *testing.T) { @@ -139,11 +259,19 @@ func TestDiscoverModels_SkipsCorruptFile(t *testing.T) { }) // Create a corrupt .gguf file (not valid GGUF binary). - require.NoError(t, os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not gguf data"), 0644)) + if err := os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not gguf data"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } models, err := DiscoverModels(dir) - require.NoError(t, err) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } // Only the valid model should be returned; corrupt one is silently skipped. 
- require.Len(t, models, 1) - assert.Equal(t, "Valid Model", models[0].Name) + if len(models) != 1 { + t.Fatalf("len(models) = %d, want 1", len(models)) + } + if models[0].Name != "Valid Model" { + t.Errorf("models[0].Name = %q, want %q", models[0].Name, "Valid Model") + } } diff --git a/docs/architecture.md b/docs/architecture.md index 278a949..ed760e2 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -4,7 +4,7 @@ go-rocm provides AMD ROCm GPU inference for Linux by managing llama-server as a subprocess. It implements the `inference.Backend` and `inference.TextModel` interfaces from go-inference, making the AMD GPU available to the broader Go ML ecosystem (go-ml, go-ai, go-i18n) without any CGO in the package itself. -Module path: `forge.lthn.ai/core/go-rocm` +Module path: `dappco.re/go/rocm` ## Design Choice: Subprocess over CGO @@ -53,7 +53,7 @@ The package uses build constraints to ensure correctness across platforms: On Linux/amd64, `register_rocm.go` calls `inference.Register(&rocmBackend{})` in an `init()` function. Any program that blank-imports go-rocm gets the backend automatically: ```go -import _ "forge.lthn.ai/core/go-rocm" +import _ "dappco.re/go/rocm" ``` The backend is then available to `inference.LoadModel()` from go-inference, which iterates registered backends and calls `Available()` on each to select one. 
diff --git a/go.mod b/go.mod index 60bb928..761f3c7 100644 --- a/go.mod +++ b/go.mod @@ -1,26 +1,16 @@ -module dappco.re/go/core/rocm +module dappco.re/go/rocm go 1.26.0 require ( - dappco.re/go/core/inference v0.1.5 - dappco.re/go/core/log v0.0.4 + dappco.re/go/inference v0.8.0-alpha.1 + dappco.re/go/log v0.8.0-alpha.1 ) -require github.com/kr/text v0.2.0 // indirect +require dappco.re/go/core v0.8.0-alpha.1 // indirect -require ( - dappco.re/go/core v0.5.0 - dappco.re/go/core/api v0.2.0 - dappco.re/go/core/i18n v0.2.0 - dappco.re/go/core/io v0.2.0 - dappco.re/go/core/log v0.1.0 - dappco.re/go/core/process v0.3.0 - dappco.re/go/core/scm v0.4.0 - dappco.re/go/core/store v0.2.0 - dappco.re/go/core/ws v0.3.0 - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/stretchr/testify v1.11.1 - gopkg.in/yaml.v3 v3.0.1 // indirect -) +replace dappco.re/go/core => ../go + +replace dappco.re/go/log => ../go-log + +replace dappco.re/go/inference => ../go-inference diff --git a/go.sum b/go.sum index f55559e..5d19fde 100644 --- a/go.sum +++ b/go.sum @@ -1,20 +1,8 @@ -forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0= -forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= -github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/gguf/ax7_test.go b/internal/gguf/ax7_test.go new file mode 100644 index 0000000..f93210e --- /dev/null +++ b/internal/gguf/ax7_test.go @@ -0,0 +1,84 @@ +package gguf + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestGGUF_FileTypeName_Good(t *testing.T) { + name := FileTypeName(15) + if name != "Q4_K_M" { + t.Fatalf("FileTypeName(15) = %q, want Q4_K_M", name) + } + if FileTypeName(17) != "Q5_K_M" { + t.Fatalf("FileTypeName(17) = %q, want Q5_K_M", FileTypeName(17)) + } +} + +func TestGGUF_FileTypeName_Bad(t *testing.T) { + name := FileTypeName(999) + if name != "type_999" { + t.Fatalf("FileTypeName(999) = %q, want type_999", name) + } + if !strings.HasPrefix(name, "type_") { + t.Fatalf("unknown file type name = %q, want type_ prefix", name) + } +} + +func TestGGUF_FileTypeName_Ugly(t *testing.T) { + zero := FileTypeName(0) + large := FileTypeName(^uint32(0)) + if zero != 
"F32" { + t.Fatalf("FileTypeName(0) = %q, want F32", zero) + } + if large != "type_4294967295" { + t.Fatalf("FileTypeName(MaxUint32) = %q, want generated name", large) + } +} + +func TestGGUF_ReadMetadata_Good(t *testing.T) { + path := writeTestGGUFOrdered(t, [][2]any{ + {"general.architecture", "llama"}, + {"general.name", "AX Llama"}, + {"general.file_type", uint32(15)}, + {"general.size_label", "8B"}, + {"llama.context_length", uint32(4096)}, + {"llama.block_count", uint32(32)}, + }) + + metadata, err := ReadMetadata(path) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } + if metadata.Architecture != "llama" || metadata.Name != "AX Llama" { + t.Fatalf("metadata = %+v, want llama metadata", metadata) + } +} + +func TestGGUF_ReadMetadata_Bad(t *testing.T) { + metadata, err := ReadMetadata(filepath.Join(t.TempDir(), "missing.gguf")) + if err == nil { + t.Fatal("ReadMetadata() error = nil, want open error") + } + if metadata != (Metadata{}) { + t.Fatalf("metadata = %+v, want zero value", metadata) + } +} + +func TestGGUF_ReadMetadata_Ugly(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "truncated.gguf") + if err := os.WriteFile(path, []byte("GGUF"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + metadata, err := ReadMetadata(path) + if err == nil { + t.Fatal("ReadMetadata() error = nil, want truncated header error") + } + if metadata != (Metadata{}) { + t.Fatalf("metadata = %+v, want zero value", metadata) + } +} diff --git a/internal/gguf/gguf.go b/internal/gguf/gguf.go index 28a290e..e7be65a 100644 --- a/internal/gguf/gguf.go +++ b/internal/gguf/gguf.go @@ -16,7 +16,7 @@ import ( "os" "strings" - coreerr "forge.lthn.ai/core/go-log" + coreerr "dappco.re/go/log" ) // ggufMagic is the GGUF file magic number: "GGUF" in little-endian. @@ -70,8 +70,10 @@ var fileTypeNames = map[uint32]string{ 18: "Q6_K", } -// FileTypeName returns a human-readable name for a GGML quantisation file type. 
-// Unknown types return "type_N" where N is the numeric value. +// FileTypeName returns a human-readable name for a GGML quantisation file +// type. Unknown types return "type_N" where N is the numeric value. +// +// name := FileTypeName(15) // "Q4_K_M" func FileTypeName(ft uint32) string { if name, ok := fileTypeNames[ft]; ok { return name @@ -79,25 +81,28 @@ return fmt.Sprintf("type_%d", ft) } // ReadMetadata reads the GGUF header from the file at path and returns the -// extracted metadata. Only metadata KV pairs are read; tensor data is not loaded. +// extracted metadata. Only metadata KV pairs are read; tensor data is not +// loaded. +// +// metadata, err := ReadMetadata("/models/gemma3-4b.gguf") func ReadMetadata(path string) (Metadata, error) { - f, err := os.Open(path) + file, err := os.Open(path) if err != nil { - return Metadata{}, err + return Metadata{}, coreerr.E("gguf.ReadMetadata", "open file", err) } - defer f.Close() + defer file.Close() - info, err := f.Stat() + fileInfo, err := file.Stat() if err != nil { - return Metadata{}, err + return Metadata{}, coreerr.E("gguf.ReadMetadata", "stat file", err) } - r := bufio.NewReader(f) + reader := bufio.NewReader(file) // Read and validate magic number. var magic uint32 - if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &magic); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading magic", err) } if magic != ggufMagic { @@ -106,7 +111,7 @@ func ReadMetadata(path string) (Metadata, error) { // Read version.
var version uint32 - if err := binary.Read(r, binary.LittleEndian, &version); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &version); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading version", err) } if version < 2 || version > 3 { @@ -116,22 +121,22 @@ func ReadMetadata(path string) (Metadata, error) { // Read tensor count and KV count. v3 uses uint64, v2 uses uint32. var tensorCount, kvCount uint64 if version == 3 { - if err := binary.Read(r, binary.LittleEndian, &tensorCount); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &tensorCount); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } - if err := binary.Read(r, binary.LittleEndian, &kvCount); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &kvCount); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } } else { - var tc, kc uint32 - if err := binary.Read(r, binary.LittleEndian, &tc); err != nil { + var tensorCount32, kvCount32 uint32 + if err := binary.Read(reader, binary.LittleEndian, &tensorCount32); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading tensor count", err) } - if err := binary.Read(r, binary.LittleEndian, &kc); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &kvCount32); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", "reading kv count", err) } - tensorCount = uint64(tc) - kvCount = uint64(kc) + tensorCount = uint64(tensorCount32) + kvCount = uint64(kvCount32) } _ = tensorCount // we only read metadata KVs @@ -139,7 +144,7 @@ func ReadMetadata(path string) (Metadata, error) { // Architecture-specific keys (e.g. llama.context_length) may appear before // the general.architecture key, so we collect all candidates and resolve after. 
var meta Metadata - meta.FileSize = info.Size() + meta.FileSize = fileInfo.Size() // candidateContextLength and candidateBlockCount store values keyed by // their full key name (e.g. "llama.context_length") so we can match them @@ -148,75 +153,75 @@ func ReadMetadata(path string) (Metadata, error) { candidateBlockCount := make(map[string]uint32) for i := uint64(0); i < kvCount; i++ { - key, err := readString(r) + key, err := readString(reader) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading key %d", i), err) } var valType uint32 - if err := binary.Read(r, binary.LittleEndian, &valType); err != nil { + if err := binary.Read(reader, binary.LittleEndian, &valType); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value type for key %q", key), err) } // Check whether this is an interesting key before reading the value. switch { case key == "general.architecture": - v, err := readTypedValue(r, valType) + value, err := readTypedValue(reader, valType) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } - if s, ok := v.(string); ok { + if s, ok := value.(string); ok { meta.Architecture = s } case key == "general.name": - v, err := readTypedValue(r, valType) + value, err := readTypedValue(reader, valType) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } - if s, ok := v.(string); ok { + if s, ok := value.(string); ok { meta.Name = s } case key == "general.file_type": - v, err := readTypedValue(r, valType) + value, err := readTypedValue(reader, valType) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } - if u, ok := v.(uint32); ok { + if u, ok := value.(uint32); ok { meta.FileType = u } case key == "general.size_label": - v, err := readTypedValue(r, valType) + value, err := readTypedValue(reader, 
valType) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } - if s, ok := v.(string); ok { + if s, ok := value.(string); ok { meta.SizeLabel = s } case strings.HasSuffix(key, ".context_length"): - v, err := readTypedValue(r, valType) + value, err := readTypedValue(reader, valType) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } - if u, ok := v.(uint32); ok { + if u, ok := value.(uint32); ok { candidateContextLength[key] = u } case strings.HasSuffix(key, ".block_count"): - v, err := readTypedValue(r, valType) + value, err := readTypedValue(reader, valType) if err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("reading value for key %q", key), err) } - if u, ok := v.(uint32); ok { + if u, ok := value.(uint32); ok { candidateBlockCount[key] = u } default: // Skip uninteresting value. - if err := skipValue(r, valType); err != nil { + if err := skipValue(reader, valType); err != nil { return Metadata{}, coreerr.E("gguf.ReadMetadata", fmt.Sprintf("skipping value for key %q", key), err) } } @@ -287,16 +292,16 @@ func readTypedValue(r io.Reader, valType uint32) (any, error) { func skipValue(r io.Reader, valType uint32) error { switch valType { case typeUint8, typeInt8, typeBool: - _, err := readN(r, 1) + _, err := discardBytes(r, 1) return err case typeUint16, typeInt16: - _, err := readN(r, 2) + _, err := discardBytes(r, 2) return err case typeUint32, typeInt32, typeFloat32: - _, err := readN(r, 4) + _, err := discardBytes(r, 4) return err case typeUint64, typeInt64, typeFloat64: - _, err := readN(r, 8) + _, err := discardBytes(r, 8) return err case typeString: var length uint64 @@ -306,7 +311,7 @@ func skipValue(r io.Reader, valType uint32) error { if length > maxStringLength { return coreerr.E("gguf.skipValue", fmt.Sprintf("string length %d exceeds maximum %d", length, maxStringLength), nil) } - _, err := 
readN(r, int64(length)) + _, err := discardBytes(r, int64(length)) return err case typeArray: var elemType uint32 @@ -328,7 +333,7 @@ func skipValue(r io.Reader, valType uint32) error { } } -// readN reads and discards exactly n bytes from r. -func readN(r io.Reader, n int64) (int64, error) { +// discardBytes reads and discards exactly n bytes from r. +func discardBytes(r io.Reader, n int64) (int64, error) { return io.CopyN(io.Discard, r, n) } diff --git a/internal/gguf/gguf_test.go b/internal/gguf/gguf_test.go index 5afb7b0..c1c4321 100644 --- a/internal/gguf/gguf_test.go +++ b/internal/gguf/gguf_test.go @@ -4,10 +4,8 @@ import ( "encoding/binary" "os" "path/filepath" + "strings" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // writeTestGGUFOrdered creates a synthetic GGUF v3 file with KV pairs in the @@ -19,17 +17,27 @@ func writeTestGGUFOrdered(t *testing.T, kvs [][2]any) string { path := filepath.Join(dir, "test.gguf") f, err := os.Create(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("os.Create: %v", err) + } defer f.Close() // Magic - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } // Version 3 - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(3))) + if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { + t.Fatalf("write version: %v", err) + } // Tensor count (uint64): 0 - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(0))) + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatalf("write tensor count: %v", err) + } // KV count (uint64) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(kvs)))) + if err := binary.Write(f, binary.LittleEndian, uint64(len(kvs))); err != nil { + t.Fatalf("write kv count: %v", err) + } for _, kv := range kvs { key := 
kv[0].(string) @@ -43,26 +51,42 @@ func writeKV(t *testing.T, f *os.File, key string, val any) { t.Helper() // Key: uint64 length + bytes - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(key)))) - _, err := f.Write([]byte(key)) - require.NoError(t, err) + if err := binary.Write(f, binary.LittleEndian, uint64(len(key))); err != nil { + t.Fatalf("write key len: %v", err) + } + if _, err := f.Write([]byte(key)); err != nil { + t.Fatalf("write key bytes: %v", err) + } switch v := val.(type) { case string: // Type: 8 (string) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(8))) + if err := binary.Write(f, binary.LittleEndian, uint32(8)); err != nil { + t.Fatalf("write string type: %v", err) + } // String value: uint64 length + bytes - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(v)))) - _, err := f.Write([]byte(v)) - require.NoError(t, err) + if err := binary.Write(f, binary.LittleEndian, uint64(len(v))); err != nil { + t.Fatalf("write string len: %v", err) + } + if _, err := f.Write([]byte(v)); err != nil { + t.Fatalf("write string bytes: %v", err) + } case uint32: // Type: 4 (uint32) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(4))) - require.NoError(t, binary.Write(f, binary.LittleEndian, v)) + if err := binary.Write(f, binary.LittleEndian, uint32(4)); err != nil { + t.Fatalf("write uint32 type: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, v); err != nil { + t.Fatalf("write uint32 val: %v", err) + } case uint64: // Type: 10 (uint64) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(10))) - require.NoError(t, binary.Write(f, binary.LittleEndian, v)) + if err := binary.Write(f, binary.LittleEndian, uint32(10)); err != nil { + t.Fatalf("write uint64 type: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, v); err != nil { + t.Fatalf("write uint64 val: %v", err) + } default: t.Fatalf("writeKV: unsupported value type %T", val) } @@ -73,12 +97,18 @@ 
func writeKV(t *testing.T, f *os.File, key string, val any) { func writeRawKV(t *testing.T, f *os.File, key string, valType uint32, rawVal []byte) { t.Helper() - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(len(key)))) - _, err := f.Write([]byte(key)) - require.NoError(t, err) - require.NoError(t, binary.Write(f, binary.LittleEndian, valType)) - _, err = f.Write(rawVal) - require.NoError(t, err) + if err := binary.Write(f, binary.LittleEndian, uint64(len(key))); err != nil { + t.Fatalf("write key len: %v", err) + } + if _, err := f.Write([]byte(key)); err != nil { + t.Fatalf("write key bytes: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, valType); err != nil { + t.Fatalf("write val type: %v", err) + } + if _, err := f.Write(rawVal); err != nil { + t.Fatalf("write raw val: %v", err) + } } // writeTestGGUFV2 creates a synthetic GGUF v2 file (uint32 tensor/kv counts). @@ -89,17 +119,27 @@ func writeTestGGUFV2(t *testing.T, kvs [][2]any) string { path := filepath.Join(dir, "test_v2.gguf") f, err := os.Create(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("os.Create: %v", err) + } defer f.Close() // Magic - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } // Version 2 - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(2))) + if err := binary.Write(f, binary.LittleEndian, uint32(2)); err != nil { + t.Fatalf("write version: %v", err) + } // Tensor count (uint32 for v2): 0 - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0))) + if err := binary.Write(f, binary.LittleEndian, uint32(0)); err != nil { + t.Fatalf("write tensor count: %v", err) + } // KV count (uint32 for v2) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(len(kvs)))) + if err := binary.Write(f, binary.LittleEndian, uint32(len(kvs))); err != nil { + t.Fatalf("write kv 
count: %v", err) + } for _, kv := range kvs { key := kv[0].(string) @@ -120,15 +160,31 @@ func TestReadMetadata_Gemma3(t *testing.T) { }) m, err := ReadMetadata(path) - require.NoError(t, err) - - assert.Equal(t, "gemma3", m.Architecture) - assert.Equal(t, "Test Gemma3 1B", m.Name) - assert.Equal(t, uint32(17), m.FileType) - assert.Equal(t, "1B", m.SizeLabel) - assert.Equal(t, uint32(32768), m.ContextLength) - assert.Equal(t, uint32(26), m.BlockCount) - assert.Greater(t, m.FileSize, int64(0)) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } + + if m.Architecture != "gemma3" { + t.Errorf("Architecture = %q, want %q", m.Architecture, "gemma3") + } + if m.Name != "Test Gemma3 1B" { + t.Errorf("Name = %q, want %q", m.Name, "Test Gemma3 1B") + } + if m.FileType != uint32(17) { + t.Errorf("FileType = %d, want 17", m.FileType) + } + if m.SizeLabel != "1B" { + t.Errorf("SizeLabel = %q, want %q", m.SizeLabel, "1B") + } + if m.ContextLength != uint32(32768) { + t.Errorf("ContextLength = %d, want 32768", m.ContextLength) + } + if m.BlockCount != uint32(26) { + t.Errorf("BlockCount = %d, want 26", m.BlockCount) + } + if m.FileSize <= 0 { + t.Errorf("FileSize = %d, want > 0", m.FileSize) + } } func TestReadMetadata_Llama(t *testing.T) { @@ -142,15 +198,31 @@ func TestReadMetadata_Llama(t *testing.T) { }) m, err := ReadMetadata(path) - require.NoError(t, err) - - assert.Equal(t, "llama", m.Architecture) - assert.Equal(t, "Test Llama 8B", m.Name) - assert.Equal(t, uint32(15), m.FileType) - assert.Equal(t, "8B", m.SizeLabel) - assert.Equal(t, uint32(131072), m.ContextLength) - assert.Equal(t, uint32(32), m.BlockCount) - assert.Greater(t, m.FileSize, int64(0)) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } + + if m.Architecture != "llama" { + t.Errorf("Architecture = %q, want %q", m.Architecture, "llama") + } + if m.Name != "Test Llama 8B" { + t.Errorf("Name = %q, want %q", m.Name, "Test Llama 8B") + } + if m.FileType != uint32(15) { + t.Errorf("FileType = %d, 
want 15", m.FileType) + } + if m.SizeLabel != "8B" { + t.Errorf("SizeLabel = %q, want %q", m.SizeLabel, "8B") + } + if m.ContextLength != uint32(131072) { + t.Errorf("ContextLength = %d, want 131072", m.ContextLength) + } + if m.BlockCount != uint32(32) { + t.Errorf("BlockCount = %d, want 32", m.BlockCount) + } + if m.FileSize <= 0 { + t.Errorf("FileSize = %d, want > 0", m.FileSize) + } } func TestReadMetadata_ArchAfterContextLength(t *testing.T) { @@ -166,37 +238,67 @@ func TestReadMetadata_ArchAfterContextLength(t *testing.T) { }) m, err := ReadMetadata(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } - assert.Equal(t, "llama", m.Architecture) - assert.Equal(t, "Out-of-Order Model", m.Name) - assert.Equal(t, uint32(4096), m.ContextLength) - assert.Equal(t, uint32(32), m.BlockCount) + if m.Architecture != "llama" { + t.Errorf("Architecture = %q, want %q", m.Architecture, "llama") + } + if m.Name != "Out-of-Order Model" { + t.Errorf("Name = %q, want %q", m.Name, "Out-of-Order Model") + } + if m.ContextLength != uint32(4096) { + t.Errorf("ContextLength = %d, want 4096", m.ContextLength) + } + if m.BlockCount != uint32(32) { + t.Errorf("BlockCount = %d, want 32", m.BlockCount) + } } func TestReadMetadata_InvalidMagic(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "notgguf.bin") - err := os.WriteFile(path, []byte("this is not a GGUF file at all"), 0644) - require.NoError(t, err) + if err := os.WriteFile(path, []byte("this is not a GGUF file at all"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } - _, err = ReadMetadata(path) - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid magic") + _, err := ReadMetadata(path) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "invalid magic") { + t.Errorf("err = %v, want contains %q", err, "invalid magic") + } } func TestReadMetadata_FileNotFound(t *testing.T) { _, err := 
ReadMetadata("/nonexistent/path/model.gguf") - require.Error(t, err) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "open file") { + t.Errorf("err = %v, want contains %q", err, "open file") + } } func TestFileTypeName(t *testing.T) { - assert.Equal(t, "Q4_K_M", FileTypeName(15)) - assert.Equal(t, "Q5_K_M", FileTypeName(17)) - assert.Equal(t, "Q8_0", FileTypeName(7)) - assert.Equal(t, "F16", FileTypeName(1)) - assert.Equal(t, "type_999", FileTypeName(999)) + cases := []struct { + ft uint32 + want string + }{ + {15, "Q4_K_M"}, + {17, "Q5_K_M"}, + {7, "Q8_0"}, + {1, "F16"}, + {999, "type_999"}, + } + for _, c := range cases { + if got := FileTypeName(c.ft); got != c.want { + t.Errorf("FileTypeName(%d) = %q, want %q", c.ft, got, c.want) + } + } } func TestReadMetadata_V2(t *testing.T) { @@ -210,13 +312,25 @@ func TestReadMetadata_V2(t *testing.T) { }) m, err := ReadMetadata(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } - assert.Equal(t, "llama", m.Architecture) - assert.Equal(t, "V2 Model", m.Name) - assert.Equal(t, uint32(15), m.FileType) - assert.Equal(t, uint32(2048), m.ContextLength) - assert.Equal(t, uint32(16), m.BlockCount) + if m.Architecture != "llama" { + t.Errorf("Architecture = %q, want %q", m.Architecture, "llama") + } + if m.Name != "V2 Model" { + t.Errorf("Name = %q, want %q", m.Name, "V2 Model") + } + if m.FileType != uint32(15) { + t.Errorf("FileType = %d, want 15", m.FileType) + } + if m.ContextLength != uint32(2048) { + t.Errorf("ContextLength = %d, want 2048", m.ContextLength) + } + if m.BlockCount != uint32(16) { + t.Errorf("BlockCount = %d, want 16", m.BlockCount) + } } func TestReadMetadata_UnsupportedVersion(t *testing.T) { @@ -224,15 +338,25 @@ func TestReadMetadata_UnsupportedVersion(t *testing.T) { path := filepath.Join(dir, "bad_version.gguf") f, err := os.Create(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("os.Create: %v", err) + } - 
require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) // magic - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(99))) // invalid version + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint32(99)); err != nil { + t.Fatalf("write version: %v", err) + } f.Close() _, err = ReadMetadata(path) - require.Error(t, err) - assert.Contains(t, err.Error(), "unsupported GGUF version") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unsupported GGUF version") { + t.Errorf("err = %v, want contains %q", err, "unsupported GGUF version") + } } func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { @@ -242,14 +366,24 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { path := filepath.Join(dir, "skip_types.gguf") f, err := os.Create(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("os.Create: %v", err) + } // Header: magic, v3, 0 tensors - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(3))) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(0))) + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { + t.Fatalf("write version: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatalf("write tensor count: %v", err) + } // 8 KV pairs: 6 skip types + 2 interesting keys - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(8))) + if err := binary.Write(f, binary.LittleEndian, uint64(8)); err != nil { + t.Fatalf("write kv count: %v", err) + } // 1. 
uint8 (type 0) — 1 byte raw := make([]byte, 1) @@ -283,7 +417,7 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { b8 := make([]byte, 8) binary.LittleEndian.PutUint64(b8, 3) // count: 3 arrBuf = append(arrBuf, b8...) - arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values + arrBuf = append(arrBuf, 10, 20, 30) // 3 uint8 values writeRawKV(t, f, "custom.array_val", 9, arrBuf) // 7-8. Interesting keys to verify parsing continued correctly. @@ -293,10 +427,16 @@ func TestReadMetadata_SkipsUnknownValueTypes(t *testing.T) { f.Close() m, err := ReadMetadata(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } - assert.Equal(t, "llama", m.Architecture) - assert.Equal(t, "Skip Test Model", m.Name) + if m.Architecture != "llama" { + t.Errorf("Architecture = %q, want %q", m.Architecture, "llama") + } + if m.Name != "Skip Test Model" { + t.Errorf("Name = %q, want %q", m.Name, "Skip Test Model") + } } func TestReadMetadata_Uint64ContextLength(t *testing.T) { @@ -309,10 +449,16 @@ func TestReadMetadata_Uint64ContextLength(t *testing.T) { }) m, err := ReadMetadata(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("ReadMetadata: %v", err) + } - assert.Equal(t, uint32(8192), m.ContextLength) - assert.Equal(t, uint32(32), m.BlockCount) + if m.ContextLength != uint32(8192) { + t.Errorf("ContextLength = %d, want 8192", m.ContextLength) + } + if m.BlockCount != uint32(32) { + t.Errorf("BlockCount = %d, want 32", m.BlockCount) + } } func TestReadMetadata_TruncatedFile(t *testing.T) { @@ -321,13 +467,21 @@ func TestReadMetadata_TruncatedFile(t *testing.T) { // Write only the magic — no version or counts. 
f, err := os.Create(path) - require.NoError(t, err) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) + if err != nil { + t.Fatalf("os.Create: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } f.Close() _, err = ReadMetadata(path) - require.Error(t, err) - assert.Contains(t, err.Error(), "reading version") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "reading version") { + t.Errorf("err = %v, want contains %q", err, "reading version") + } } func TestReadMetadata_SkipsStringValue(t *testing.T) { @@ -336,12 +490,22 @@ func TestReadMetadata_SkipsStringValue(t *testing.T) { path := filepath.Join(dir, "skip_string.gguf") f, err := os.Create(path) - require.NoError(t, err) + if err != nil { + t.Fatalf("os.Create: %v", err) + } - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(0x46554747))) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint32(3))) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(0))) - require.NoError(t, binary.Write(f, binary.LittleEndian, uint64(2))) + if err := binary.Write(f, binary.LittleEndian, uint32(0x46554747)); err != nil { + t.Fatalf("write magic: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { + t.Fatalf("write version: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatalf("write tensor count: %v", err) + } + if err := binary.Write(f, binary.LittleEndian, uint64(2)); err != nil { + t.Fatalf("write kv count: %v", err) + } // Uninteresting string key (exercises skipValue for typeString). 
 	writeKV(t, f, "custom.description", "a long description value")
@@ -351,6 +515,10 @@ func TestReadMetadata_SkipsStringValue(t *testing.T) {
 	f.Close()
 
 	m, err := ReadMetadata(path)
-	require.NoError(t, err)
-	assert.Equal(t, "gemma3", m.Architecture)
+	if err != nil {
+		t.Fatalf("ReadMetadata: %v", err)
+	}
+	if m.Architecture != "gemma3" {
+		t.Errorf("Architecture = %q, want %q", m.Architecture, "gemma3")
+	}
 }
diff --git a/internal/llamacpp/ax7_test.go b/internal/llamacpp/ax7_test.go
new file mode 100644
index 0000000..28dcddb
--- /dev/null
+++ b/internal/llamacpp/ax7_test.go
@@ -0,0 +1,252 @@
+package llamacpp
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"strings"
+	"testing"
+)
+
+func TestClient_NewClient_Good(t *testing.T) {
+	client := NewClient("http://127.0.0.1:38080")
+	if client.baseURL != "http://127.0.0.1:38080" {
+		t.Fatalf("baseURL = %q, want trimmed base URL", client.baseURL)
+	}
+	if client.httpClient == nil {
+		t.Fatal("httpClient = nil, want default client")
+	}
+}
+
+func TestClient_NewClient_Bad(t *testing.T) {
+	client := NewClient("")
+	if client.baseURL != "" {
+		t.Fatalf("baseURL = %q, want empty base URL preserved", client.baseURL)
+	}
+	if client.httpClient == nil {
+		t.Fatal("httpClient = nil, want default client")
+	}
+}
+
+func TestClient_NewClient_Ugly(t *testing.T) {
+	client := NewClient("http://127.0.0.1:38080///")
+	if client.baseURL != "http://127.0.0.1:38080" {
+		t.Fatalf("baseURL = %q, want all trailing slashes trimmed", client.baseURL)
+	}
+	if client.httpClient == nil {
+		t.Fatal("httpClient = nil, want default client")
+	}
+}
+
+func TestClient_NewClientWithHTTPClient_Good(t *testing.T) {
+	httpClient := &http.Client{}
+	client := NewClientWithHTTPClient("http://llama.test/", httpClient)
+	if client.httpClient != httpClient {
+		t.Fatal("httpClient was not preserved")
+	}
+	if client.baseURL != "http://llama.test" {
+		t.Fatalf("baseURL = %q, want trimmed base URL", client.baseURL)
+	}
+}
+
+func TestClient_NewClientWithHTTPClient_Bad(t *testing.T) {
+	client := NewClientWithHTTPClient("http://llama.test", nil)
+	if client.httpClient == nil {
+		t.Fatal("httpClient = nil, want default client")
+	}
+	if client.baseURL != "http://llama.test" {
+		t.Fatalf("baseURL = %q, want original base URL", client.baseURL)
+	}
+}
+
+func TestClient_NewClientWithHTTPClient_Ugly(t *testing.T) {
+	httpClient := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
+		return nil, fmt.Errorf("unused")
+	})}
+	client := NewClientWithHTTPClient("http://llama.test////", httpClient)
+	if client.baseURL != "http://llama.test" {
+		t.Fatalf("baseURL = %q, want deeply trimmed base URL", client.baseURL)
+	}
+	if client.httpClient.Transport == nil {
+		t.Fatal("transport = nil, want injected transport")
+	}
+}
+
+func TestClient_Client_Health_Good(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/health" {
+			t.Fatalf("path = %q, want /health", r.URL.Path)
+		}
+		_, _ = w.Write([]byte(`{"status":"ok"}`))
+	}))
+	defer ts.Close()
+
+	client := NewClient(ts.URL)
+	if err := client.Health(context.Background()); err != nil {
+		t.Fatalf("Health() = %v, want nil", err)
+	}
+}
+
+func TestClient_Client_Health_Bad(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		http.Error(w, "loading", http.StatusServiceUnavailable)
+	}))
+	defer ts.Close()
+
+	client := NewClient(ts.URL)
+	err := client.Health(context.Background())
+	if err == nil {
+		t.Fatal("Health() error = nil, want non-200 error")
+	}
+	if !strings.Contains(err.Error(), "503") {
+		t.Fatalf("Health() = %v, want 503", err)
+	}
+}
+
+func TestClient_Client_Health_Ugly(t *testing.T) {
+	client := NewClient("http://%zz")
+	err := client.Health(context.Background())
+	if err == nil {
+		t.Fatal("Health() error = nil, want request creation error")
+	}
+	if !strings.Contains(err.Error(), "create health request") {
+		t.Fatalf("Health() = %v, want create health request", err)
+	}
+}
+
+func TestClient_Client_Complete_Good(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/completions" {
+			t.Fatalf("path = %q, want /v1/completions", r.URL.Path)
+		}
+		sseLines(w, []string{
+			`{"choices":[{"text":"A","finish_reason":null}]}`,
+			`{"choices":[{"text":"B","finish_reason":null}]}`,
+			"[DONE]",
+		})
+	}))
+	defer ts.Close()
+
+	client := NewClient(ts.URL)
+	tokens, errFn := client.Complete(context.Background(), CompletionRequest{Prompt: "go"})
+	got := collectStringSeq(tokens)
+	if err := errFn(); err != nil {
+		t.Fatalf("errFn() = %v, want nil", err)
+	}
+	if !reflect.DeepEqual(got, []string{"A", "B"}) {
+		t.Fatalf("tokens = %v, want A/B", got)
+	}
+}
+
+func TestClient_Client_Complete_Bad(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		http.Error(w, "bad request", http.StatusBadRequest)
+	}))
+	defer ts.Close()
+
+	client := NewClient(ts.URL)
+	tokens, errFn := client.Complete(context.Background(), CompletionRequest{Prompt: "go"})
+	got := collectStringSeq(tokens)
+	if len(got) != 0 {
+		t.Fatalf("tokens = %v, want empty", got)
+	}
+	if err := errFn(); err == nil || !strings.Contains(err.Error(), "400") {
+		t.Fatalf("errFn() = %v, want 400", err)
+	}
+}
+
+func TestClient_Client_Complete_Ugly(t *testing.T) {
+	client := NewClientWithHTTPClient("http://llama.test", &http.Client{
+		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+				Body:       io.NopCloser(strings.NewReader("data: {\"choices\":[{\"text\":\"partial\"}]}\n\n")),
+				Request:    r,
+			}, nil
+		}),
+	})
+
+	tokens, errFn := client.Complete(context.Background(), CompletionRequest{Prompt: "go"})
+	got := collectStringSeq(tokens)
+	if !reflect.DeepEqual(got, []string{"partial"}) {
+		t.Fatalf("tokens = %v, want partial", got)
+	}
+	if err := errFn(); err == nil || !strings.Contains(err.Error(), "stream ended before [DONE]") {
+		t.Fatalf("errFn() = %v, want truncated stream", err)
+	}
+}
+
+func TestClient_Client_ChatComplete_Good(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Fatalf("path = %q, want /v1/chat/completions", r.URL.Path)
+		}
+		sseLines(w, []string{
+			`{"choices":[{"delta":{"content":"hi"},"finish_reason":null}]}`,
+			`{"choices":[{"delta":{"content":" there"},"finish_reason":null}]}`,
+			"[DONE]",
+		})
+	}))
+	defer ts.Close()
+
+	client := NewClient(ts.URL)
+	tokens, errFn := client.ChatComplete(context.Background(), ChatRequest{Messages: []ChatMessage{{Role: "user", Content: "hi"}}})
+	got := collectStringSeq(tokens)
+	if err := errFn(); err != nil {
+		t.Fatalf("errFn() = %v, want nil", err)
+	}
+	if !reflect.DeepEqual(got, []string{"hi", " there"}) {
+		t.Fatalf("tokens = %v, want chat chunks", got)
+	}
+}
+
+func TestClient_Client_ChatComplete_Bad(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		http.Error(w, "chat failed", http.StatusInternalServerError)
+	}))
+	defer ts.Close()
+
+	client := NewClient(ts.URL)
+	tokens, errFn := client.ChatComplete(context.Background(), ChatRequest{Messages: []ChatMessage{{Role: "user", Content: "hi"}}})
+	got := collectStringSeq(tokens)
+	if len(got) != 0 {
+		t.Fatalf("tokens = %v, want empty", got)
+	}
+	if err := errFn(); err == nil || !strings.Contains(err.Error(), "500") {
+		t.Fatalf("errFn() = %v, want 500", err)
+	}
+}
+
+func TestClient_Client_ChatComplete_Ugly(t *testing.T) {
+	client := NewClientWithHTTPClient("http://llama.test", &http.Client{
+		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+				Body:       io.NopCloser(strings.NewReader("data: {\"choices\":[{\"delta\":{\"content\":\"partial\"}}]}\n\n")),
+				Request:    r,
+			}, nil
+		}),
+	})
+
+	tokens, errFn := client.ChatComplete(context.Background(), ChatRequest{Messages: []ChatMessage{{Role: "user", Content: "hi"}}})
+	got := collectStringSeq(tokens)
+	if !reflect.DeepEqual(got, []string{"partial"}) {
+		t.Fatalf("tokens = %v, want partial", got)
+	}
+	if err := errFn(); err == nil || !strings.Contains(err.Error(), "stream ended before [DONE]") {
+		t.Fatalf("errFn() = %v, want truncated stream", err)
+	}
+}
+
+func collectStringSeq(seq func(func(string) bool)) []string {
+	var got []string
+	for token := range seq {
+		got = append(got, token)
+	}
+	return got
+}
diff --git a/internal/llamacpp/client.go b/internal/llamacpp/client.go
index 2fb0c11..cfb4871 100644
--- a/internal/llamacpp/client.go
+++ b/internal/llamacpp/client.go
@@ -12,7 +12,7 @@ import (
 	"strings"
 	"sync"
 
-	coreerr "forge.lthn.ai/core/go-log"
+	coreerr "dappco.re/go/log"
 )
 
 // ChatMessage is a single message in a conversation.
@@ -43,7 +43,7 @@ type CompletionRequest struct {
 	Stream bool `json:"stream"`
 }
 
-type chatChunkResponse struct {
+type chatStreamChunkResponse struct {
 	Choices []struct {
 		Delta struct {
 			Content string `json:"content"`
@@ -52,56 +52,61 @@ type chatChunkResponse struct {
 	} `json:"choices"`
 }
 
-type completionChunkResponse struct {
+type completionStreamChunkResponse struct {
 	Choices []struct {
 		Text         string  `json:"text"`
 		FinishReason *string `json:"finish_reason"`
 	} `json:"choices"`
 }
 
-// ChatComplete sends a streaming chat completion request to /v1/chat/completions.
-// It returns an iterator over text chunks and a function that returns any error
-// that occurred during the request or while reading the stream.
+// ChatComplete sends a streaming chat completion request to
+// /v1/chat/completions. It returns an iterator over text chunks and a function
+// that returns any error that occurred during the request or while reading the
+// stream.
+//
+//	chunks, streamError := client.ChatComplete(ctx, ChatRequest{
+//		Messages: []ChatMessage{{Role: "user", Content: "Hi"}},
+//	})
 func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[string], func() error) {
 	req.Stream = true
 
-	body, err := json.Marshal(req)
+	requestBody, err := json.Marshal(req)
 	if err != nil {
-		return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) }
+		return noStreamChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "marshal chat request", err) }
 	}
 
-	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(body))
+	httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/chat/completions", bytes.NewReader(requestBody))
 	if err != nil {
-		return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) }
+		return noStreamChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "create chat request", err) }
 	}
-	httpReq.Header.Set("Content-Type", "application/json")
-	httpReq.Header.Set("Accept", "text/event-stream")
+	httpRequest.Header.Set("Content-Type", "application/json")
+	httpRequest.Header.Set("Accept", "text/event-stream")
 
-	resp, err := c.httpClient.Do(httpReq)
+	response, err := c.httpClient.Do(httpRequest)
 	if err != nil {
-		return noChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) }
+		return noStreamChunks, func() error { return coreerr.E("llamacpp.ChatComplete", "chat request", err) }
 	}
-	if resp.StatusCode != http.StatusOK {
-		defer resp.Body.Close()
-		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
-		return noChunks, func() error {
-			return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
+	if response.StatusCode != http.StatusOK {
+		defer response.Body.Close()
+		responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
+		return noStreamChunks, func() error {
+			return coreerr.E("llamacpp.ChatComplete", fmt.Sprintf("chat returned %d: %s", response.StatusCode, strings.TrimSpace(string(responseBody))), nil)
 		}
 	}
 
 	var (
 		streamErr error
 		closeOnce sync.Once
-		closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) }
+		closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) }
 	)
-	sseData := parseSSE(resp.Body, &streamErr)
+	eventDataStream := streamSSEData(response.Body, &streamErr)
 
-	tokens := func(yield func(string) bool) {
+	tokenStream := func(yield func(string) bool) {
 		defer closeBody()
-		for raw := range sseData {
-			var chunk chatChunkResponse
-			if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
+		for rawChunk := range eventDataStream {
+			var chunk chatStreamChunkResponse
+			if err := json.Unmarshal([]byte(rawChunk), &chunk); err != nil {
 				streamErr = coreerr.E("llamacpp.ChatComplete", "decode chat chunk", err)
 				return
 			}
@@ -118,55 +123,59 @@ func (c *Client) ChatComplete(ctx context.Context, req ChatRequest) (iter.Seq[st
 		}
 	}
 
-	return tokens, func() error {
+	return tokenStream, func() error {
 		closeBody()
 		return streamErr
 	}
 }
 
-// Complete sends a streaming completion request to /v1/completions.
-// It returns an iterator over text chunks and a function that returns any error
+// Complete sends a streaming completion request to /v1/completions. It
+// returns an iterator over text chunks and a function that returns any error
 // that occurred during the request or while reading the stream.
+//
+//	chunks, streamError := client.Complete(ctx, CompletionRequest{
+//		Prompt: "Hello",
+//	})
 func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[string], func() error) {
 	req.Stream = true
 
-	body, err := json.Marshal(req)
+	requestBody, err := json.Marshal(req)
 	if err != nil {
-		return noChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) }
+		return noStreamChunks, func() error { return coreerr.E("llamacpp.Complete", "marshal completion request", err) }
 	}
 
-	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(body))
+	httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/completions", bytes.NewReader(requestBody))
 	if err != nil {
-		return noChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) }
+		return noStreamChunks, func() error { return coreerr.E("llamacpp.Complete", "create completion request", err) }
 	}
-	httpReq.Header.Set("Content-Type", "application/json")
-	httpReq.Header.Set("Accept", "text/event-stream")
+	httpRequest.Header.Set("Content-Type", "application/json")
+	httpRequest.Header.Set("Accept", "text/event-stream")
 
-	resp, err := c.httpClient.Do(httpReq)
+	response, err := c.httpClient.Do(httpRequest)
 	if err != nil {
-		return noChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) }
+		return noStreamChunks, func() error { return coreerr.E("llamacpp.Complete", "completion request", err) }
 	}
-	if resp.StatusCode != http.StatusOK {
-		defer resp.Body.Close()
-		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
-		return noChunks, func() error {
-			return coreerr.E("llamacpp.Complete", fmt.Sprintf("completion returned %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))), nil)
+	if response.StatusCode != http.StatusOK {
+		defer response.Body.Close()
+		responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
+		return noStreamChunks, func() error {
+			return coreerr.E("llamacpp.Complete",
+				fmt.Sprintf("completion returned %d: %s", response.StatusCode, strings.TrimSpace(string(responseBody))), nil)
 		}
 	}
 
 	var (
 		streamErr error
 		closeOnce sync.Once
-		closeBody = func() { closeOnce.Do(func() { resp.Body.Close() }) }
+		closeBody = func() { closeOnce.Do(func() { response.Body.Close() }) }
 	)
-	sseData := parseSSE(resp.Body, &streamErr)
+	eventDataStream := streamSSEData(response.Body, &streamErr)
 
-	tokens := func(yield func(string) bool) {
+	tokenStream := func(yield func(string) bool) {
 		defer closeBody()
-		for raw := range sseData {
-			var chunk completionChunkResponse
-			if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
+		for rawChunk := range eventDataStream {
+			var chunk completionStreamChunkResponse
+			if err := json.Unmarshal([]byte(rawChunk), &chunk); err != nil {
 				streamErr = coreerr.E("llamacpp.Complete", "decode completion chunk", err)
 				return
 			}
@@ -183,18 +192,19 @@ func (c *Client) Complete(ctx context.Context, req CompletionRequest) (iter.Seq[
 		}
 	}
 
-	return tokens, func() error {
+	return tokenStream, func() error {
 		closeBody()
 		return streamErr
 	}
 }
 
-// parseSSE reads SSE-formatted lines from r and yields the payload of each
-// "data: " line. It stops when it encounters "[DONE]" or an I/O error.
-// Any read error (other than EOF) is stored via errOut.
-func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
+// streamSSEData reads SSE-formatted lines from r and yields the payload of
+// each "data: " line. llama-server terminates successful streams with a
+// "[DONE]" sentinel; EOF before that marker is treated as a truncated stream.
+func streamSSEData(r io.Reader, errOut *error) iter.Seq[string] {
 	return func(yield func(string) bool) {
 		scanner := bufio.NewScanner(r)
+		sawDone := false
 		for scanner.Scan() {
 			line := scanner.Text()
 			if !strings.HasPrefix(line, "data: ") {
@@ -202,6 +212,7 @@ func parseSSE(r io.Reader, errOut *error) iter.Seq[string] {
 			}
 			payload := strings.TrimPrefix(line, "data: ")
 			if payload == "[DONE]" {
+				sawDone = true
 				return
 			}
 			if !yield(payload) {
@@ -209,10 +220,15 @@
 			}
 		}
 		if err := scanner.Err(); err != nil {
-			*errOut = coreerr.E("llamacpp.parseSSE", "read SSE stream", err)
+			*errOut = coreerr.E("llamacpp.streamSSEData", "read SSE stream", err)
+			return
+		}
+		if !sawDone {
+			*errOut = coreerr.E("llamacpp.streamSSEData", "stream ended before [DONE]", io.ErrUnexpectedEOF)
 		}
 	}
 }
 
-// noChunks is an empty iterator returned when an error occurs before streaming begins.
-func noChunks(func(string) bool) {}
+// noStreamChunks is an empty iterator returned when an error occurs before
+// streaming begins.
+func noStreamChunks(func(string) bool) {}
diff --git a/internal/llamacpp/client_test.go b/internal/llamacpp/client_test.go
index 20b49c7..8bd37ac 100644
--- a/internal/llamacpp/client_test.go
+++ b/internal/llamacpp/client_test.go
@@ -3,14 +3,20 @@ package llamacpp
 import (
 	"context"
 	"fmt"
+	"io"
 	"net/http"
 	"net/http/httptest"
+	"reflect"
+	"strings"
 	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
 )
 
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
+	return fn(r)
+}
+
 // sseLines writes SSE-formatted lines to a flushing response writer.
 func sseLines(w http.ResponseWriter, lines []string) {
 	f, ok := w.(http.Flusher)
@@ -29,8 +35,12 @@ func sseLines(w http.ResponseWriter, lines []string) {
 
 func TestChatComplete_Streaming(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
-		assert.Equal(t, "POST", r.Method)
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/v1/chat/completions")
+		}
+		if r.Method != "POST" {
+			t.Errorf("r.Method = %q, want %q", r.Method, "POST")
+		}
 		sseLines(w, []string{
 			`{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}`,
 			`{"choices":[{"delta":{"content":" world"},"finish_reason":null}]}`,
@@ -52,8 +62,13 @@ func TestChatComplete_Streaming(t *testing.T) {
 	for tok := range tokens {
 		got = append(got, tok)
 	}
-	require.NoError(t, errFn())
-	assert.Equal(t, []string{"Hello", " world"}, got)
+	if err := errFn(); err != nil {
+		t.Fatalf("errFn: %v", err)
+	}
+	want := []string{"Hello", " world"}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("tokens = %v, want %v", got, want)
+	}
 }
 
 func TestChatComplete_EmptyResponse(t *testing.T) {
@@ -73,8 +88,12 @@ func TestChatComplete_EmptyResponse(t *testing.T) {
 	for tok := range tokens {
 		got = append(got, tok)
 	}
-	require.NoError(t, errFn())
-	assert.Empty(t, got)
+	if err := errFn(); err != nil {
+		t.Fatalf("errFn: %v", err)
+	}
+	if len(got) != 0 {
+		t.Errorf("got = %v, want empty", got)
+	}
 }
 
 func TestChatComplete_HTTPError(t *testing.T) {
@@ -94,10 +113,16 @@ func TestChatComplete_HTTPError(t *testing.T) {
 	for tok := range tokens {
 		got = append(got, tok)
 	}
-	assert.Empty(t, got)
+	if len(got) != 0 {
+		t.Errorf("got = %v, want empty", got)
+	}
 	err := errFn()
-	require.Error(t, err)
-	assert.Contains(t, err.Error(), "500")
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "500") {
+		t.Errorf("err = %v, want contains %q", err, "500")
+	}
 }
 
 func TestChatComplete_ContextCancelled(t *testing.T) {
@@ -137,13 +162,20 @@ func TestChatComplete_ContextCancelled(t *testing.T) {
 	// The error may or may not be nil depending on timing;
 	// the important thing is we got exactly 1 token.
 	_ = errFn()
-	assert.Equal(t, []string{"Hello"}, got)
+	want := []string{"Hello"}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("tokens = %v, want %v", got, want)
+	}
 }
 
 func TestComplete_Streaming(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		assert.Equal(t, "/v1/completions", r.URL.Path)
-		assert.Equal(t, "POST", r.Method)
+		if r.URL.Path != "/v1/completions" {
+			t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/v1/completions")
+		}
+		if r.Method != "POST" {
+			t.Errorf("r.Method = %q, want %q", r.Method, "POST")
+		}
 		sseLines(w, []string{
 			`{"choices":[{"text":"Once","finish_reason":null}]}`,
 			`{"choices":[{"text":" upon","finish_reason":null}]}`,
@@ -166,8 +198,13 @@ func TestComplete_Streaming(t *testing.T) {
 	for tok := range tokens {
 		got = append(got, tok)
 	}
-	require.NoError(t, errFn())
-	assert.Equal(t, []string{"Once", " upon", " a time"}, got)
+	if err := errFn(); err != nil {
+		t.Fatalf("errFn: %v", err)
+	}
+	want := []string{"Once", " upon", " a time"}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("tokens = %v, want %v", got, want)
+	}
 }
 
 func TestComplete_HTTPError(t *testing.T) {
@@ -187,8 +224,54 @@ func TestComplete_HTTPError(t *testing.T) {
 	for tok := range tokens {
 		got = append(got, tok)
 	}
-	assert.Empty(t, got)
+	if len(got) != 0 {
+		t.Errorf("got = %v, want empty", got)
+	}
 	err := errFn()
-	require.Error(t, err)
-	assert.Contains(t, err.Error(), "400")
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "400") {
+		t.Errorf("err = %v, want contains %q", err, "400")
+	}
+}
+
+func TestComplete_TruncatedStreamReturnsError(t *testing.T) {
+	c := NewClientWithHTTPClient("http://llama.test", &http.Client{
+		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
+			if r.URL.Path != "/v1/completions" {
+				return nil, fmt.Errorf("unexpected path %q", r.URL.Path)
+			}
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+				Body: io.NopCloser(strings.NewReader(
+					"data: " + `{"choices":[{"text":"partial","finish_reason":null}]}` + "\n\n",
+				)),
+				Request: r,
+			}, nil
+		}),
+	})
+	tokens, errFn := c.Complete(context.Background(), CompletionRequest{
+		Prompt:      "Hello",
+		Temperature: 0.0,
+		Stream:      true,
+	})
+
+	var got []string
+	for tok := range tokens {
+		got = append(got, tok)
+	}
+
+	want := []string{"partial"}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("tokens = %v, want %v", got, want)
+	}
+	err := errFn()
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "stream ended before [DONE]") {
+		t.Errorf("err = %v, want contains %q", err, "stream ended before [DONE]")
+	}
 }
diff --git a/internal/llamacpp/health.go b/internal/llamacpp/health.go
index 33ec57b..592fe6b 100644
--- a/internal/llamacpp/health.go
+++ b/internal/llamacpp/health.go
@@ -8,7 +8,7 @@ import (
 	"net/http"
 	"strings"
 
-	coreerr "forge.lthn.ai/core/go-log"
+	coreerr "dappco.re/go/log"
 )
 
 // Client communicates with a llama-server instance.
@@ -17,40 +17,55 @@ type Client struct {
 	httpClient *http.Client
 }
 
 // NewClient creates a client for the llama-server at the given base URL.
+//
+//	client := NewClient("http://127.0.0.1:38080")
 func NewClient(baseURL string) *Client {
+	return NewClientWithHTTPClient(baseURL, &http.Client{})
+}
+
+// NewClientWithHTTPClient creates a client with an injected HTTP transport.
+//
+//	client := NewClientWithHTTPClient("http://127.0.0.1:38080", customHTTPClient)
+func NewClientWithHTTPClient(baseURL string, httpClient *http.Client) *Client {
+	if httpClient == nil {
+		httpClient = &http.Client{}
+	}
 	return &Client{
 		baseURL:    strings.TrimRight(baseURL, "/"),
-		httpClient: &http.Client{},
+		httpClient: httpClient,
 	}
 }
 
-type healthResponse struct {
+type healthStatusResponse struct {
 	Status string `json:"status"`
 }
 
 // Health checks whether the llama-server is ready to accept requests.
+//
+//	err := client.Health(ctx)
+//	fmt.Println(err == nil)
 func (c *Client) Health(ctx context.Context) error {
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
+	request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/health", nil)
 	if err != nil {
-		return err
+		return coreerr.E("llamacpp.Health", "create health request", err)
 	}
 
-	resp, err := c.httpClient.Do(req)
+	response, err := c.httpClient.Do(request)
 	if err != nil {
-		return err
+		return coreerr.E("llamacpp.Health", "health request", err)
 	}
-	defer resp.Body.Close()
+	defer response.Body.Close()
 
-	if resp.StatusCode != http.StatusOK {
-		body, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
-		return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", resp.StatusCode, string(body)), nil)
+	if response.StatusCode != http.StatusOK {
+		responseBody, _ := io.ReadAll(io.LimitReader(response.Body, 256))
+		return coreerr.E("llamacpp.Health", fmt.Sprintf("health returned %d: %s", response.StatusCode, string(responseBody)), nil)
 	}
 
-	var h healthResponse
-	if err := json.NewDecoder(resp.Body).Decode(&h); err != nil {
+	var healthStatus healthStatusResponse
+	if err := json.NewDecoder(response.Body).Decode(&healthStatus); err != nil {
 		return coreerr.E("llamacpp.Health", "health decode", err)
 	}
-	if h.Status != "ok" {
-		return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", h.Status), nil)
+	if healthStatus.Status != "ok" {
+		return coreerr.E("llamacpp.Health", fmt.Sprintf("server not ready (status: %s)", healthStatus.Status), nil)
 	}
 
 	return nil
 }
diff --git a/internal/llamacpp/health_test.go b/internal/llamacpp/health_test.go
index 38affcf..f79c4c7 100644
--- a/internal/llamacpp/health_test.go
+++ b/internal/llamacpp/health_test.go
@@ -4,23 +4,24 @@ import (
 	"context"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
-
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
 )
 
 func TestHealth_OK(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		assert.Equal(t, "/health", r.URL.Path)
+		if r.URL.Path != "/health" {
+			t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/health")
+		}
 		w.Header().Set("Content-Type", "application/json")
 		w.Write([]byte(`{"status":"ok"}`))
 	}))
 	defer ts.Close()
 
 	c := NewClient(ts.URL)
-	err := c.Health(context.Background())
-	require.NoError(t, err)
+	if err := c.Health(context.Background()); err != nil {
+		t.Fatalf("Health: %v", err)
+	}
 }
 
 func TestHealth_NotReady(t *testing.T) {
@@ -32,7 +33,12 @@ func TestHealth_NotReady(t *testing.T) {
 
 	c := NewClient(ts.URL)
 	err := c.Health(context.Background())
-	assert.ErrorContains(t, err, "not ready")
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "not ready") {
+		t.Errorf("err = %v, want contains %q", err, "not ready")
+	}
 }
 
 func TestHealth_Loading(t *testing.T) {
@@ -45,11 +51,32 @@ func TestHealth_Loading(t *testing.T) {
 
 	c := NewClient(ts.URL)
 	err := c.Health(context.Background())
-	assert.ErrorContains(t, err, "503")
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "503") {
+		t.Errorf("err = %v, want contains %q", err, "503")
+	}
 }
 
 func TestHealth_ServerDown(t *testing.T) {
 	c := NewClient("http://127.0.0.1:1") // nothing listening
 	err := c.Health(context.Background())
-	assert.Error(t, err)
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "health request") {
+		t.Errorf("err = %v, want contains %q", err, "health request")
+	}
+}
+
+func TestHealth_InvalidBaseURL(t *testing.T) {
+	c := NewClient("http://%zz")
+	err := c.Health(context.Background())
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "create health request") {
+		t.Errorf("err = %v, want contains %q", err, "create health request")
+	}
 }
diff --git a/model.go b/model.go
index 5ae5c47..71a53a7 100644
--- a/model.go
+++ b/model.go
@@ -10,77 +10,68 @@ import (
 	"sync"
 	"time"
 
-	coreerr "forge.lthn.ai/core/go-log"
-	"forge.lthn.ai/core/go-inference"
-	"forge.lthn.ai/core/go-rocm/internal/llamacpp"
+	"dappco.re/go/inference"
+	coreerr "dappco.re/go/log"
+	"dappco.re/go/rocm/internal/llamacpp"
 )
 
 // rocmModel implements inference.TextModel using a llama-server subprocess.
 type rocmModel struct {
-	srv       *server
+	server    *server
 	modelType string
 	modelInfo inference.ModelInfo
 
-	mu      sync.Mutex
-	lastErr error
-	metrics inference.GenerateMetrics
+	stateMutex  sync.Mutex
+	lastError   error
+	lastMetrics inference.GenerateMetrics
 }
 
 // Generate streams tokens for the given prompt via llama-server's /v1/completions endpoint.
 func (m *rocmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
-	m.mu.Lock()
-	m.lastErr = nil
-	m.mu.Unlock()
+	m.clearLastError()
 
-	if !m.srv.alive() {
+	if !m.server.alive() {
 		m.setServerExitErr()
 		return func(yield func(inference.Token) bool) {}
 	}
 
-	cfg := inference.ApplyGenerateOpts(opts)
-
-	req := llamacpp.CompletionRequest{
-		Prompt:        prompt,
-		MaxTokens:     cfg.MaxTokens,
-		Temperature:   cfg.Temperature,
-		TopK:          cfg.TopK,
-		TopP:          cfg.TopP,
-		RepeatPenalty: cfg.RepeatPenalty,
-	}
+	generateConfig := inference.ApplyGenerateOpts(opts)
+	request := newCompletionRequest(prompt, generateConfig)
+	promptTokens := approximatePromptTokens(prompt)
 
 	start := time.Now()
-	chunks, errFn := m.srv.client.Complete(ctx, req)
+	chunks, streamError := m.server.llamaClient.Complete(ctx, request)
 
 	return func(yield func(inference.Token) bool) {
 		var count int
-		decodeStart := time.Now()
+		var firstTokenAt time.Time
 		for text := range chunks {
+			if firstTokenAt.IsZero() {
+				firstTokenAt = time.Now()
+			}
 			count++
 			if !yield(inference.Token{Text: text}) {
 				break
 			}
 		}
-		if err := errFn(); err != nil {
-			m.mu.Lock()
-			m.lastErr = err
-			m.mu.Unlock()
+		if err := streamError(); err != nil {
+			m.setLastError(err)
 		}
-		m.recordMetrics(0, count, start, decodeStart)
+		m.recordMetrics(promptTokens, count, start, firstTokenAt)
 	}
 }
 
 // Chat streams tokens from a multi-turn conversation via llama-server's /v1/chat/completions endpoint.
 func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
-	m.mu.Lock()
-	m.lastErr = nil
-	m.mu.Unlock()
+	m.clearLastError()
 
-	if !m.srv.alive() {
+	if !m.server.alive() {
 		m.setServerExitErr()
 		return func(yield func(inference.Token) bool) {}
 	}
 
-	cfg := inference.ApplyGenerateOpts(opts)
+	generateConfig := inference.ApplyGenerateOpts(opts)
+	promptTokens := approximateMessageTokens(messages)
 
 	chatMsgs := make([]llamacpp.ChatMessage, len(messages))
 	for i, msg := range messages {
@@ -89,119 +80,136 @@ func (m *rocmModel) Chat(ctx context.Context, messages []inference.Message, opts
 			Content: msg.Content,
 		}
 	}
-
-	req := llamacpp.ChatRequest{
-		Messages:      chatMsgs,
-		MaxTokens:     cfg.MaxTokens,
-		Temperature:   cfg.Temperature,
-		TopK:          cfg.TopK,
-		TopP:          cfg.TopP,
-		RepeatPenalty: cfg.RepeatPenalty,
-	}
+	request := newChatRequest(chatMsgs, generateConfig)
 
 	start := time.Now()
-	chunks, errFn := m.srv.client.ChatComplete(ctx, req)
+	chunks, streamError := m.server.llamaClient.ChatComplete(ctx, request)
 
 	return func(yield func(inference.Token) bool) {
 		var count int
-		decodeStart := time.Now()
+		var firstTokenAt time.Time
 		for text := range chunks {
+			if firstTokenAt.IsZero() {
+				firstTokenAt = time.Now()
+			}
 			count++
 			if !yield(inference.Token{Text: text}) {
 				break
 			}
 		}
-		if err := errFn(); err != nil {
-			m.mu.Lock()
-			m.lastErr = err
-			m.mu.Unlock()
+		if err := streamError(); err != nil {
+			m.setLastError(err)
 		}
-		m.recordMetrics(0, count, start, decodeStart)
+		m.recordMetrics(promptTokens, count, start, firstTokenAt)
 	}
 }
 
 // Classify runs batched prefill-only inference via llama-server.
-// Each prompt gets a single-token completion (max_tokens=1, temperature=0).
-// llama-server has no native classify endpoint, so this simulates it.
+// Each prompt gets a single-token completion (max_tokens=1) while honoring
+// the sampling settings from opts. llama-server has no native classify
+// endpoint, so this simulates it.
 func (m *rocmModel) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
-	if !m.srv.alive() {
+	if !m.server.alive() {
 		m.setServerExitErr()
 		return nil, m.Err()
 	}
 
-	start := time.Now()
+	generateConfig := inference.ApplyGenerateOpts(opts)
 	results := make([]inference.ClassifyResult, len(prompts))
-
-	for i, prompt := range prompts {
-		if ctx.Err() != nil {
-			return nil, ctx.Err()
+	totalPromptTokens := 0
+	totalGenerated := 0
+	var totalPrefill time.Duration
+	var totalDecode time.Duration
+
+	for promptIndex, prompt := range prompts {
+		if contextError := ctx.Err(); contextError != nil {
+			m.recordMetricsDurations(totalPromptTokens, totalGenerated, totalPrefill, totalDecode)
+			return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify cancelled before prompt %d", promptIndex), contextError)
 		}
 
-		req := llamacpp.CompletionRequest{
-			Prompt:      prompt,
-			MaxTokens:   1,
-			Temperature: 0,
-		}
+		totalPromptTokens += approximatePromptTokens(prompt)
+		request := newCompletionRequest(prompt, generateConfig)
+		request.MaxTokens = 1
 
-		chunks, errFn := m.srv.client.Complete(ctx, req)
+		requestStart := time.Now()
+		chunks, streamError := m.server.llamaClient.Complete(ctx, request)
 
 		var text strings.Builder
+		var firstTokenAt time.Time
+		var generated int
 		for chunk := range chunks {
+			if firstTokenAt.IsZero() {
+				firstTokenAt = time.Now()
+			}
+			generated++
 			text.WriteString(chunk)
 		}
-		if err := errFn(); err != nil {
-			return nil, coreerr.E("rocm.Classify", fmt.Sprintf("classify prompt %d", i), err)
+		requestEnd := time.Now()
+		prefill, decode := splitDurations(requestStart, firstTokenAt, requestEnd)
+		totalPrefill += prefill
+		totalDecode += decode
+		totalGenerated += generated
+
+		if err := streamError(); err != nil {
+			m.recordMetricsDurations(totalPromptTokens, totalGenerated, totalPrefill, totalDecode)
+			return nil, coreerr.E("rocm.Classify",
fmt.Sprintf("classify prompt %d", promptIndex), err) } - results[i] = inference.ClassifyResult{ + results[promptIndex] = inference.ClassifyResult{ Token: inference.Token{Text: text.String()}, } } - m.recordMetrics(len(prompts), len(prompts), start, start) + m.recordMetricsDurations(totalPromptTokens, totalGenerated, totalPrefill, totalDecode) return results, nil } // BatchGenerate runs batched autoregressive generation via llama-server. // Each prompt is decoded sequentially up to MaxTokens. func (m *rocmModel) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) { - if !m.srv.alive() { + if !m.server.alive() { m.setServerExitErr() return nil, m.Err() } - cfg := inference.ApplyGenerateOpts(opts) - start := time.Now() + generateConfig := inference.ApplyGenerateOpts(opts) results := make([]inference.BatchResult, len(prompts)) + totalPromptTokens := 0 var totalGenerated int + var totalPrefill time.Duration + var totalDecode time.Duration - for i, prompt := range prompts { - if ctx.Err() != nil { - results[i].Err = ctx.Err() + for promptIndex, prompt := range prompts { + if contextError := ctx.Err(); contextError != nil { + results[promptIndex].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d cancelled before start", promptIndex), contextError) continue } - req := llamacpp.CompletionRequest{ - Prompt: prompt, - MaxTokens: cfg.MaxTokens, - Temperature: cfg.Temperature, - TopK: cfg.TopK, - TopP: cfg.TopP, - RepeatPenalty: cfg.RepeatPenalty, - } + totalPromptTokens += approximatePromptTokens(prompt) + request := newCompletionRequest(prompt, generateConfig) - chunks, errFn := m.srv.client.Complete(ctx, req) + requestStart := time.Now() + chunks, streamError := m.server.llamaClient.Complete(ctx, request) var tokens []inference.Token + var firstTokenAt time.Time for text := range chunks { + if firstTokenAt.IsZero() { + firstTokenAt = time.Now() + } tokens = append(tokens, 
inference.Token{Text: text}) } - if err := errFn(); err != nil { - results[i].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", i), err) - } - results[i].Tokens = tokens + requestEnd := time.Now() + prefill, decode := splitDurations(requestStart, firstTokenAt, requestEnd) + totalPrefill += prefill + totalDecode += decode + results[promptIndex].Tokens = tokens totalGenerated += len(tokens) + + if err := streamError(); err != nil { + results[promptIndex].Err = coreerr.E("rocm.BatchGenerate", fmt.Sprintf("batch prompt %d", promptIndex), err) + } } - m.recordMetrics(len(prompts), totalGenerated, start, start) + m.recordMetricsDurations(totalPromptTokens, totalGenerated, totalPrefill, totalDecode) return results, nil } @@ -213,42 +221,50 @@ func (m *rocmModel) Info() inference.ModelInfo { return m.modelInfo } // Metrics returns performance metrics from the last inference operation. func (m *rocmModel) Metrics() inference.GenerateMetrics { - m.mu.Lock() - defer m.mu.Unlock() - return m.metrics + m.stateMutex.Lock() + defer m.stateMutex.Unlock() + return m.lastMetrics } // Err returns the error from the last Generate/Chat call, if any. func (m *rocmModel) Err() error { - m.mu.Lock() - defer m.mu.Unlock() - return m.lastErr + m.stateMutex.Lock() + defer m.stateMutex.Unlock() + return m.lastError } // Close releases the llama-server subprocess and all associated resources. func (m *rocmModel) Close() error { - return m.srv.stop() + return m.server.stop() } // setServerExitErr stores an appropriate error when the server is dead. 
func (m *rocmModel) setServerExitErr() { - m.mu.Lock() - defer m.mu.Unlock() - if m.srv.exitErr != nil { - m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited", m.srv.exitErr) + m.stateMutex.Lock() + defer m.stateMutex.Unlock() + if m.server.processExitError != nil { + m.lastError = m.server.wrapProcessError("rocm.setServerExitErr", "server has exited", m.server.processExitError) } else { - m.lastErr = coreerr.E("rocm.setServerExitErr", "server has exited unexpectedly", nil) + m.lastError = coreerr.E("rocm.setServerExitErr", m.server.messageWithProcessOutput("server has exited unexpectedly"), nil) } } // recordMetrics captures timing data from an inference operation. -func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, decodeStart time.Time) { - now := time.Now() - total := now.Sub(start) - decode := now.Sub(decodeStart) - prefill := total - decode +func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, firstTokenAt time.Time) { + prefill, decode := splitDurations(start, firstTokenAt, time.Now()) + m.recordMetricsDurations(promptTokens, generatedTokens, prefill, decode) +} + +func (m *rocmModel) recordMetricsDurations(promptTokens, generatedTokens int, prefill, decode time.Duration) { + if prefill < 0 { + prefill = 0 + } + if decode < 0 { + decode = 0 + } + total := prefill + decode - met := inference.GenerateMetrics{ + metrics := inference.GenerateMetrics{ PromptTokens: promptTokens, GeneratedTokens: generatedTokens, PrefillDuration: prefill, @@ -256,19 +272,75 @@ func (m *rocmModel) recordMetrics(promptTokens, generatedTokens int, start, deco TotalDuration: total, } if prefill > 0 && promptTokens > 0 { - met.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds() + metrics.PrefillTokensPerSec = float64(promptTokens) / prefill.Seconds() } if decode > 0 && generatedTokens > 0 { - met.DecodeTokensPerSec = float64(generatedTokens) / decode.Seconds() + metrics.DecodeTokensPerSec = 
float64(generatedTokens) / decode.Seconds() } // Try to get VRAM stats — best effort. if vram, err := GetVRAMInfo(); err == nil { - met.PeakMemoryBytes = vram.Used - met.ActiveMemoryBytes = vram.Used + metrics.PeakMemoryBytes = vram.Used + metrics.ActiveMemoryBytes = vram.Used } - m.mu.Lock() - m.metrics = met - m.mu.Unlock() + m.stateMutex.Lock() + m.lastMetrics = metrics + m.stateMutex.Unlock() +} + +func (m *rocmModel) clearLastError() { + m.setLastError(nil) +} + +func (m *rocmModel) setLastError(err error) { + m.stateMutex.Lock() + m.lastError = err + m.stateMutex.Unlock() +} + +func newCompletionRequest(prompt string, generateConfig inference.GenerateConfig) llamacpp.CompletionRequest { + return llamacpp.CompletionRequest{ + Prompt: prompt, + MaxTokens: generateConfig.MaxTokens, + Temperature: generateConfig.Temperature, + TopK: generateConfig.TopK, + TopP: generateConfig.TopP, + RepeatPenalty: generateConfig.RepeatPenalty, + } +} + +func newChatRequest(messages []llamacpp.ChatMessage, generateConfig inference.GenerateConfig) llamacpp.ChatRequest { + return llamacpp.ChatRequest{ + Messages: messages, + MaxTokens: generateConfig.MaxTokens, + Temperature: generateConfig.Temperature, + TopK: generateConfig.TopK, + TopP: generateConfig.TopP, + RepeatPenalty: generateConfig.RepeatPenalty, + } +} + +func splitDurations(start, firstTokenAt, end time.Time) (time.Duration, time.Duration) { + if start.IsZero() || end.Before(start) { + return 0, 0 + } + if firstTokenAt.IsZero() || firstTokenAt.Before(start) || firstTokenAt.After(end) { + return end.Sub(start), 0 + } + return firstTokenAt.Sub(start), end.Sub(firstTokenAt) +} + +// llama-server's streaming API does not expose prompt token counts, so metrics +// use a lightweight whitespace-token approximation for prefill throughput. 
+func approximatePromptTokens(prompt string) int { + return len(strings.Fields(prompt)) +} + +func approximateMessageTokens(messages []inference.Message) int { + total := 0 + for _, msg := range messages { + total += approximatePromptTokens(msg.Content) + } + return total } diff --git a/model_ax7_test.go b/model_ax7_test.go new file mode 100644 index 0000000..90f3fe4 --- /dev/null +++ b/model_ax7_test.go @@ -0,0 +1,380 @@ +//go:build linux && amd64 + +package rocm + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "os/exec" + "reflect" + "strings" + "testing" + "time" + + "dappco.re/go/inference" + "dappco.re/go/rocm/internal/llamacpp" +) + +func TestModel_Model_Generate_Good(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/completions" { + t.Fatalf("r.URL.Path = %q, want /v1/completions", r.URL.Path) + } + writeSSEEvent(w, `{"choices":[{"text":"Hello","finish_reason":null}]}`) + writeSSEEvent(w, `{"choices":[{"text":" ROCm","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + model := newHTTPBackedModel(ts) + var got []string + for token := range model.Generate(context.Background(), "hello rocm", inference.WithMaxTokens(2)) { + got = append(got, token.Text) + } + if err := model.Err(); err != nil { + t.Fatalf("Err() = %v, want nil", err) + } + if !reflect.DeepEqual(got, []string{"Hello", " ROCm"}) { + t.Fatalf("tokens = %v, want Hello ROCm chunks", got) + } +} + +func TestModel_Model_Generate_Bad(t *testing.T) { + processExited := make(chan struct{}) + close(processExited) + model := &rocmModel{server: &server{processExited: processExited}} + + var count int + for range model.Generate(context.Background(), "hello") { + count++ + } + if count != 0 { + t.Fatalf("token count = %d, want 0", count) + } + if err := model.Err(); err == nil { + t.Fatal("Err() = nil, want server exit error") + } +} + +func 
TestModel_Model_Generate_Ugly(t *testing.T) { + client := llamacpp.NewClientWithHTTPClient("http://llama.test", &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader("data: {\"choices\":[{\"text\":\"partial\"}]}\n\n")), + Request: r, + }, nil + }), + }) + model := newClientBackedModel(client) + + var got []string + for token := range model.Generate(context.Background(), "hello") { + got = append(got, token.Text) + } + if !reflect.DeepEqual(got, []string{"partial"}) { + t.Fatalf("tokens = %v, want partial token", got) + } + if err := model.Err(); err == nil || !strings.Contains(err.Error(), "stream ended before [DONE]") { + t.Fatalf("Err() = %v, want truncated stream error", err) + } +} + +func TestModel_Model_Chat_Good(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + t.Fatalf("r.URL.Path = %q, want /v1/chat/completions", r.URL.Path) + } + writeSSEEvent(w, `{"choices":[{"delta":{"content":"Ready"},"finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + model := newHTTPBackedModel(ts) + var got []string + for token := range model.Chat(context.Background(), []inference.Message{{Role: "user", Content: "status"}}) { + got = append(got, token.Text) + } + if err := model.Err(); err != nil { + t.Fatalf("Err() = %v, want nil", err) + } + if !reflect.DeepEqual(got, []string{"Ready"}) { + t.Fatalf("tokens = %v, want Ready", got) + } +} + +func TestModel_Model_Chat_Bad(t *testing.T) { + processExited := make(chan struct{}) + close(processExited) + model := &rocmModel{server: &server{processExited: processExited}} + + var count int + for range model.Chat(context.Background(), []inference.Message{{Role: "user", Content: "hello"}}) { + count++ + } + if 
count != 0 { + t.Fatalf("token count = %d, want 0", count) + } + if err := model.Err(); err == nil { + t.Fatal("Err() = nil, want server exit error") + } +} + +func TestModel_Model_Chat_Ugly(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "chat failed", http.StatusInternalServerError) + })) + defer ts.Close() + + model := newHTTPBackedModel(ts) + var got []string + for token := range model.Chat(context.Background(), []inference.Message{{Role: "user", Content: "hello"}}) { + got = append(got, token.Text) + } + if len(got) != 0 { + t.Fatalf("tokens = %v, want empty", got) + } + if err := model.Err(); err == nil || !strings.Contains(err.Error(), "chat returned 500") { + t.Fatalf("Err() = %v, want chat returned 500", err) + } +} + +func TestModel_Model_Classify_Good(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/completions" { + t.Fatalf("r.URL.Path = %q, want /v1/completions", r.URL.Path) + } + writeSSEEvent(w, `{"choices":[{"text":"positive","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + model := newHTTPBackedModel(ts) + results, err := model.Classify(context.Background(), []string{"good"}) + if err != nil { + t.Fatalf("Classify: %v", err) + } + if len(results) != 1 || results[0].Token.Text != "positive" { + t.Fatalf("results = %+v, want positive label", results) + } +} + +func TestModel_Model_Classify_Bad(t *testing.T) { + processExited := make(chan struct{}) + close(processExited) + model := &rocmModel{server: &server{processExited: processExited}} + + results, err := model.Classify(context.Background(), []string{"bad"}) + if err == nil { + t.Fatal("Classify error = nil, want server exit error") + } + if results != nil { + t.Fatalf("results = %+v, want nil", results) + } +} + +func TestModel_Model_Classify_Ugly(t *testing.T) { + client := 
llamacpp.NewClientWithHTTPClient("http://llama.test", &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader("data: {\"choices\":[{\"text\":\"partial\"}]}\n\n")), + Request: r, + }, nil + }), + }) + model := newClientBackedModel(client) + + results, err := model.Classify(context.Background(), []string{"edge"}) + if err == nil { + t.Fatal("Classify error = nil, want truncated stream error") + } + if results != nil { + t.Fatalf("results = %+v, want nil", results) + } +} + +func TestModel_Model_BatchGenerate_Good(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEEvent(w, `{"choices":[{"text":"A","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + model := newHTTPBackedModel(ts) + results, err := model.BatchGenerate(context.Background(), []string{"one", "two"}, inference.WithMaxTokens(1)) + if err != nil { + t.Fatalf("BatchGenerate: %v", err) + } + if len(results) != 2 || len(results[0].Tokens) != 1 || len(results[1].Tokens) != 1 { + t.Fatalf("results = %+v, want one token per prompt", results) + } +} + +func TestModel_Model_BatchGenerate_Bad(t *testing.T) { + processExited := make(chan struct{}) + close(processExited) + model := &rocmModel{server: &server{processExited: processExited}} + + results, err := model.BatchGenerate(context.Background(), []string{"bad"}) + if err == nil { + t.Fatal("BatchGenerate error = nil, want server exit error") + } + if results != nil { + t.Fatalf("results = %+v, want nil", results) + } +} + +func TestModel_Model_BatchGenerate_Ugly(t *testing.T) { + model := newClientBackedModel(llamacpp.NewClient("http://127.0.0.1:1")) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + results, err := model.BatchGenerate(ctx, 
[]string{"cancelled"}) + if err != nil { + t.Fatalf("BatchGenerate: %v", err) + } + if len(results) != 1 || results[0].Err == nil { + t.Fatalf("results = %+v, want per-prompt cancellation error", results) + } +} + +func TestModel_Model_ModelType_Good(t *testing.T) { + model := &rocmModel{modelType: "gemma3"} + got := model.ModelType() + if got != "gemma3" { + t.Fatalf("ModelType() = %q, want gemma3", got) + } +} + +func TestModel_Model_ModelType_Bad(t *testing.T) { + model := &rocmModel{} + got := model.ModelType() + if got != "" { + t.Fatalf("ModelType() = %q, want empty string", got) + } +} + +func TestModel_Model_ModelType_Ugly(t *testing.T) { + model := &rocmModel{modelType: "vendor.experimental"} + first := model.ModelType() + second := model.ModelType() + if first != second { + t.Fatalf("ModelType() changed from %q to %q", first, second) + } +} + +func TestModel_Model_Info_Good(t *testing.T) { + model := &rocmModel{modelInfo: inference.ModelInfo{Architecture: "llama", NumLayers: 32, QuantBits: 4}} + info := model.Info() + if info.Architecture != "llama" { + t.Fatalf("Info().Architecture = %q, want llama", info.Architecture) + } + if info.NumLayers != 32 || info.QuantBits != 4 { + t.Fatalf("Info() = %+v, want layer and quant metadata", info) + } +} + +func TestModel_Model_Info_Bad(t *testing.T) { + model := &rocmModel{} + info := model.Info() + if info != (inference.ModelInfo{}) { + t.Fatalf("Info() = %+v, want zero value", info) + } +} + +func TestModel_Model_Info_Ugly(t *testing.T) { + model := &rocmModel{modelInfo: inference.ModelInfo{Architecture: "qwen2", NumLayers: 28}} + info := model.Info() + info.Architecture = "mutated" + if model.Info().Architecture != "qwen2" { + t.Fatalf("Info() returned mutable internal state: %+v", model.Info()) + } +} + +func TestModel_Model_Metrics_Good(t *testing.T) { + model := &rocmModel{} + model.recordMetricsDurations(2, 3, 10*time.Millisecond, 20*time.Millisecond) + metrics := model.Metrics() + if metrics.PromptTokens != 
2 || metrics.GeneratedTokens != 3 { + t.Fatalf("Metrics() = %+v, want prompt and generated token counts", metrics) + } + if metrics.TotalDuration != 30*time.Millisecond { + t.Fatalf("TotalDuration = %s, want 30ms", metrics.TotalDuration) + } +} + +func TestModel_Model_Metrics_Bad(t *testing.T) { + model := &rocmModel{} + metrics := model.Metrics() + if metrics != (inference.GenerateMetrics{}) { + t.Fatalf("Metrics() = %+v, want zero value", metrics) + } +} + +func TestModel_Model_Metrics_Ugly(t *testing.T) { + model := &rocmModel{} + model.recordMetricsDurations(1, 1, -time.Second, -time.Second) + metrics := model.Metrics() + if metrics.TotalDuration != 0 { + t.Fatalf("TotalDuration = %s, want clamped zero", metrics.TotalDuration) + } +} + +func TestModel_Model_Err_Good(t *testing.T) { + model := &rocmModel{} + sentinel := errors.New("decode failed") + model.setLastError(sentinel) + if !errors.Is(model.Err(), sentinel) { + t.Fatalf("Err() = %v, want sentinel", model.Err()) + } +} + +func TestModel_Model_Err_Bad(t *testing.T) { + model := &rocmModel{} + model.setLastError(errors.New("temporary")) + model.clearLastError() + if model.Err() != nil { + t.Fatalf("Err() = %v, want nil", model.Err()) + } +} + +func TestModel_Model_Err_Ugly(t *testing.T) { + model := &rocmModel{server: &server{processOutput: newProcessOutputCapture(serverProcessOutputLimit)}} + model.setServerExitErr() + if model.Err() == nil { + t.Fatal("Err() = nil, want server exit error") + } +} + +func TestModel_Model_Close_Good(t *testing.T) { + model := &rocmModel{server: &server{processCommand: &exec.Cmd{}}} + err := model.Close() + if err != nil { + t.Fatalf("Close() = %v, want nil", err) + } +} + +func TestModel_Model_Close_Bad(t *testing.T) { + model := &rocmModel{} + defer func() { + if recover() == nil { + t.Fatal("Close() panic = nil, want panic for missing server") + } + }() + _ = model.Close() +} + +func TestModel_Model_Close_Ugly(t *testing.T) { + model := &rocmModel{server: 
&server{processCommand: &exec.Cmd{}}} + first := model.Close() + second := model.Close() + if first != nil || second != nil { + t.Fatalf("Close() errors = %v, %v; want nil", first, second) + } +} diff --git a/model_test.go b/model_test.go new file mode 100644 index 0000000..c00c762 --- /dev/null +++ b/model_test.go @@ -0,0 +1,354 @@ +//go:build linux && amd64 + +package rocm + +import ( + "context" + "encoding/json" + "io" + "math" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "sync" + "testing" + "time" + + "dappco.re/go/inference" + "dappco.re/go/rocm/internal/llamacpp" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return fn(r) +} + +func newClientBackedModel(client *llamacpp.Client) *rocmModel { + return &rocmModel{ + server: &server{ + llamaClient: client, + processExited: make(chan struct{}), + }, + } +} + +func newHTTPBackedModel(ts *httptest.Server) *rocmModel { + return newClientBackedModel(llamacpp.NewClient(ts.URL)) +} + +func writeSSEEvent(w http.ResponseWriter, payload string) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Write([]byte("data: " + payload + "\n\n")) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } +} + +func TestGenerate_MetricsSplitPrefillAndDecode(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/completions" { + t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/v1/completions") + } + + time.Sleep(25 * time.Millisecond) + writeSSEEvent(w, `{"choices":[{"text":"Hello","finish_reason":null}]}`) + time.Sleep(25 * time.Millisecond) + writeSSEEvent(w, `{"choices":[{"text":" world","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + m := newHTTPBackedModel(ts) + + var got []string + for tok := range m.Generate(context.Background(), "hello 
world", inference.WithMaxTokens(2)) { + got = append(got, tok.Text) + } + + if err := m.Err(); err != nil { + t.Fatalf("m.Err(): %v", err) + } + want := []string{"Hello", " world"} + if !reflect.DeepEqual(got, want) { + t.Errorf("tokens = %v, want %v", got, want) + } + + met := m.Metrics() + if met.PromptTokens != 2 { + t.Errorf("PromptTokens = %d, want 2", met.PromptTokens) + } + if met.GeneratedTokens != 2 { + t.Errorf("GeneratedTokens = %d, want 2", met.GeneratedTokens) + } + if met.PrefillDuration < 20*time.Millisecond { + t.Errorf("PrefillDuration = %s, want >= 20ms", met.PrefillDuration) + } + if met.DecodeDuration < 20*time.Millisecond { + t.Errorf("DecodeDuration = %s, want >= 20ms", met.DecodeDuration) + } + if met.TotalDuration < 45*time.Millisecond { + t.Errorf("TotalDuration = %s, want >= 45ms", met.TotalDuration) + } + if met.PrefillTokensPerSec <= 0 { + t.Errorf("PrefillTokensPerSec = %v, want > 0", met.PrefillTokensPerSec) + } + if met.DecodeTokensPerSec <= 0 { + t.Errorf("DecodeTokensPerSec = %v, want > 0", met.DecodeTokensPerSec) + } +} + +func TestClassify_AppliesGenerateOptions(t *testing.T) { + var ( + mu sync.Mutex + requests []llamacpp.CompletionRequest + ) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/completions" { + t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/v1/completions") + } + + var req llamacpp.CompletionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + mu.Lock() + requests = append(requests, req) + mu.Unlock() + + writeSSEEvent(w, `{"choices":[{"text":"label","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + m := newHTTPBackedModel(ts) + + results, err := m.Classify( + context.Background(), + []string{"hello world"}, + inference.WithTemperature(0.7), + inference.WithTopK(42), + inference.WithTopP(0.91), + inference.WithRepeatPenalty(1.3), + ) + if err != nil { 
+ t.Fatalf("Classify: %v", err) + } + if len(results) != 1 { + t.Fatalf("len(results) = %d, want 1", len(results)) + } + if results[0].Token.Text != "label" { + t.Errorf("results[0].Token.Text = %q, want %q", results[0].Token.Text, "label") + } + + mu.Lock() + if len(requests) != 1 { + mu.Unlock() + t.Fatalf("len(requests) = %d, want 1", len(requests)) + } + req := requests[0] + mu.Unlock() + + if req.Prompt != "hello world" { + t.Errorf("req.Prompt = %q, want %q", req.Prompt, "hello world") + } + if req.MaxTokens != 1 { + t.Errorf("req.MaxTokens = %d, want 1", req.MaxTokens) + } + if math.Abs(req.Temperature-0.7) > 0.001 { + t.Errorf("req.Temperature = %v, want ~0.7", req.Temperature) + } + if req.TopK != 42 { + t.Errorf("req.TopK = %d, want 42", req.TopK) + } + if math.Abs(req.TopP-0.91) > 0.001 { + t.Errorf("req.TopP = %v, want ~0.91", req.TopP) + } + if math.Abs(req.RepeatPenalty-1.3) > 0.001 { + t.Errorf("req.RepeatPenalty = %v, want ~1.3", req.RepeatPenalty) + } +} + +func TestBatchGenerate_MetricsAggregatePrefillAndDecode(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/completions" { + t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/v1/completions") + } + + time.Sleep(15 * time.Millisecond) + writeSSEEvent(w, `{"choices":[{"text":"A","finish_reason":null}]}`) + time.Sleep(15 * time.Millisecond) + writeSSEEvent(w, `{"choices":[{"text":"B","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + })) + defer ts.Close() + + m := newHTTPBackedModel(ts) + + results, err := m.BatchGenerate(context.Background(), []string{"alpha beta", "gamma delta"}, inference.WithMaxTokens(2)) + if err != nil { + t.Fatalf("BatchGenerate: %v", err) + } + if len(results) != 2 { + t.Fatalf("len(results) = %d, want 2", len(results)) + } + if len(results[0].Tokens) != 2 { + t.Fatalf("len(results[0].Tokens) = %d, want 2", len(results[0].Tokens)) + } + if len(results[1].Tokens) != 2 { + 
t.Fatalf("len(results[1].Tokens) = %d, want 2", len(results[1].Tokens)) + } + + met := m.Metrics() + if met.PromptTokens != 4 { + t.Errorf("PromptTokens = %d, want 4", met.PromptTokens) + } + if met.GeneratedTokens != 4 { + t.Errorf("GeneratedTokens = %d, want 4", met.GeneratedTokens) + } + if met.PrefillDuration < 20*time.Millisecond { + t.Errorf("PrefillDuration = %s, want >= 20ms", met.PrefillDuration) + } + if met.DecodeDuration < 20*time.Millisecond { + t.Errorf("DecodeDuration = %s, want >= 20ms", met.DecodeDuration) + } + if met.TotalDuration < 50*time.Millisecond { + t.Errorf("TotalDuration = %s, want >= 50ms", met.TotalDuration) + } + if met.PrefillTokensPerSec <= 0 { + t.Errorf("PrefillTokensPerSec = %v, want > 0", met.PrefillTokensPerSec) + } + if met.DecodeTokensPerSec <= 0 { + t.Errorf("DecodeTokensPerSec = %v, want > 0", met.DecodeTokensPerSec) + } +} + +func TestClassify_ContextCancelledRecordsMetricsAndWrapsError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var requestCount int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + writeSSEEvent(w, `{"choices":[{"text":"label","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + if requestCount == 1 { + cancel() + } + })) + defer ts.Close() + + m := newHTTPBackedModel(ts) + + results, err := m.Classify(ctx, []string{"hello world", "goodbye world"}) + if err == nil { + t.Fatal("expected error, got nil") + } + if results != nil { + t.Errorf("results = %v, want nil", results) + } + if requestCount != 1 { + t.Errorf("requestCount = %d, want 1", requestCount) + } + if !strings.Contains(err.Error(), "rocm.Classify") { + t.Errorf("err = %v, want contains %q", err, "rocm.Classify") + } + if !strings.Contains(err.Error(), "cancelled before prompt 1") { + t.Errorf("err = %v, want contains %q", err, "cancelled before prompt 1") + } + + metrics := m.Metrics() + if metrics.PromptTokens != 2 { + 
t.Errorf("PromptTokens = %d, want 2", metrics.PromptTokens) + } + if metrics.GeneratedTokens != 1 { + t.Errorf("GeneratedTokens = %d, want 1", metrics.GeneratedTokens) + } +} + +func TestBatchGenerate_ContextCancelledWrapsPerPromptError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var requestCount int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + writeSSEEvent(w, `{"choices":[{"text":"token","finish_reason":null}]}`) + writeSSEEvent(w, "[DONE]") + if requestCount == 1 { + cancel() + } + })) + defer ts.Close() + + m := newHTTPBackedModel(ts) + + results, err := m.BatchGenerate(ctx, []string{"hello world", "goodbye world"}, inference.WithMaxTokens(1)) + if err != nil { + t.Fatalf("BatchGenerate: %v", err) + } + if len(results) != 2 { + t.Fatalf("len(results) = %d, want 2", len(results)) + } + if requestCount != 1 { + t.Errorf("requestCount = %d, want 1", requestCount) + } + if len(results[0].Tokens) != 1 { + t.Fatalf("len(results[0].Tokens) = %d, want 1", len(results[0].Tokens)) + } + if results[1].Err == nil { + t.Fatal("results[1].Err = nil, want error") + } + if !strings.Contains(results[1].Err.Error(), "rocm.BatchGenerate") { + t.Errorf("results[1].Err = %v, want contains %q", results[1].Err, "rocm.BatchGenerate") + } + if !strings.Contains(results[1].Err.Error(), "cancelled before start") { + t.Errorf("results[1].Err = %v, want contains %q", results[1].Err, "cancelled before start") + } + + metrics := m.Metrics() + if metrics.PromptTokens != 2 { + t.Errorf("PromptTokens = %d, want 2", metrics.PromptTokens) + } + if metrics.GeneratedTokens != 1 { + t.Errorf("GeneratedTokens = %d, want 1", metrics.GeneratedTokens) + } +} + +func TestGenerate_TruncatedStreamSetsLastError(t *testing.T) { + client := llamacpp.NewClientWithHTTPClient("http://llama.test", &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if 
r.URL.Path != "/v1/completions" {
+				t.Errorf("r.URL.Path = %q, want %q", r.URL.Path, "/v1/completions")
+			}
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+				Body: io.NopCloser(strings.NewReader(
+					"data: " + `{"choices":[{"text":"partial","finish_reason":null}]}` + "\n\n",
+				)),
+				Request: r,
+			}, nil
+		}),
+	})
+
+	m := newClientBackedModel(client)
+
+	var got []string
+	for tok := range m.Generate(context.Background(), "hello", inference.WithMaxTokens(1)) {
+		got = append(got, tok.Text)
+	}
+
+	want := []string{"partial"}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("tokens = %v, want %v", got, want)
+	}
+	if err := m.Err(); err == nil {
+		t.Fatal("m.Err() = nil, want error")
+	} else if !strings.Contains(err.Error(), "stream ended before [DONE]") {
+		t.Errorf("m.Err() = %v, want contains %q", err, "stream ended before [DONE]")
+	}
+}
diff --git a/register_rocm.go b/register_rocm.go
index f7b6e3b..86b1ab1 100644
--- a/register_rocm.go
+++ b/register_rocm.go
@@ -2,11 +2,15 @@
 package rocm
 
-import "forge.lthn.ai/core/go-inference"
+import "dappco.re/go/inference"
 
 func init() { inference.Register(&rocmBackend{}) }
 
 // ROCmAvailable reports whether ROCm GPU inference is available.
+//
+//	if rocm.ROCmAvailable() {
+//		fmt.Println("ROCm code path compiled in")
+//	}
func ROCmAvailable() bool { return true } diff --git a/rocm.go b/rocm.go index bea7178..034f526 100644 --- a/rocm.go +++ b/rocm.go @@ -6,8 +6,8 @@ // # Quick Start // // import ( -// "forge.lthn.ai/core/go-inference" -// _ "forge.lthn.ai/core/go-rocm" // auto-registers ROCm backend +// "dappco.re/go/inference" +// _ "dappco.re/go/rocm" // auto-registers ROCm backend // ) // // m, err := inference.LoadModel("/path/to/model.gguf") diff --git a/rocm_benchmark_test.go b/rocm_benchmark_test.go index b833eaa..b257a42 100644 --- a/rocm_benchmark_test.go +++ b/rocm_benchmark_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "forge.lthn.ai/core/go-inference" + "dappco.re/go/inference" ) // benchModels lists the models to benchmark. diff --git a/rocm_integration_test.go b/rocm_integration_test.go index 1856036..48a6285 100644 --- a/rocm_integration_test.go +++ b/rocm_integration_test.go @@ -11,9 +11,7 @@ import ( "testing" "time" - "forge.lthn.ai/core/go-inference" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "dappco.re/go/inference" ) const testModel = "/data/lem/gguf/LEK-Gemma3-1B-layered-v2-Q5_K_M.gguf" @@ -40,13 +38,19 @@ func TestROCm_LoadAndGenerate(t *testing.T) { skipIfNoModel(t) b := &rocmBackend{} - require.True(t, b.Available()) + if !b.Available() { + t.Fatal("b.Available() = false, want true") + } m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() - assert.Equal(t, "gemma3", m.ModelType()) + if got := m.ModelType(); got != "gemma3" { + t.Errorf("ModelType() = %q, want %q", got, "gemma3") + } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -56,8 +60,12 @@ func TestROCm_LoadAndGenerate(t *testing.T) { tokens = append(tokens, tok.Text) } - require.NoError(t, m.Err()) - require.NotEmpty(t, tokens, "expected at least one token") + if err := m.Err(); err != nil { + 
t.Fatalf("m.Err(): %v", err) + } + if len(tokens) == 0 { + t.Fatal("expected at least one token") + } full := "" for _, tok := range tokens { @@ -72,7 +80,9 @@ func TestROCm_Chat(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -87,8 +97,12 @@ func TestROCm_Chat(t *testing.T) { tokens = append(tokens, tok.Text) } - require.NoError(t, m.Err()) - require.NotEmpty(t, tokens, "expected at least one token") + if err := m.Err(); err != nil { + t.Fatalf("m.Err(): %v", err) + } + if len(tokens) == 0 { + t.Fatal("expected at least one token") + } full := "" for _, tok := range tokens { @@ -103,7 +117,9 @@ func TestROCm_ContextCancellation(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -118,7 +134,9 @@ func TestROCm_ContextCancellation(t *testing.T) { } t.Logf("Got %d tokens before cancel", count) - assert.GreaterOrEqual(t, count, 3) + if count < 3 { + t.Errorf("count = %d, want >= 3", count) + } } func TestROCm_GracefulShutdown(t *testing.T) { @@ -127,7 +145,9 @@ func TestROCm_GracefulShutdown(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() // Cancel mid-stream. 
@@ -152,8 +172,12 @@ func TestROCm_GracefulShutdown(t *testing.T) { count2++ } - require.NoError(t, m.Err()) - assert.Greater(t, count2, 0, "expected tokens from second generation after cancel") + if err := m.Err(); err != nil { + t.Fatalf("m.Err(): %v", err) + } + if count2 == 0 { + t.Error("expected tokens from second generation after cancel") + } t.Logf("Second generation: %d tokens", count2) } @@ -163,7 +187,9 @@ func TestROCm_ConcurrentRequests(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() const numGoroutines = 3 @@ -197,7 +223,9 @@ func TestROCm_ConcurrentRequests(t *testing.T) { for i, result := range results { t.Logf("Goroutine %d: %s", i, result) - assert.NotEmpty(t, result, "goroutine %d produced no output", i) + if result == "" { + t.Errorf("goroutine %d produced no output", i) + } } } @@ -207,7 +235,9 @@ func TestROCm_Classify(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -219,11 +249,17 @@ func TestROCm_Classify(t *testing.T) { } results, err := m.Classify(ctx, prompts) - require.NoError(t, err) - require.Len(t, results, 2) + if err != nil { + t.Fatalf("Classify: %v", err) + } + if len(results) != 2 { + t.Fatalf("len(results) = %d, want 2", len(results)) + } for i, r := range results { - assert.NotEmpty(t, r.Token.Text, "classify result %d should have a token", i) + if r.Token.Text == "" { + t.Errorf("classify result %d should have a token", i) + } t.Logf("Classify %d: %q", i, r.Token.Text) } } @@ -234,7 +270,9 @@ func TestROCm_BatchGenerate(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != 
nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) @@ -246,12 +284,20 @@ func TestROCm_BatchGenerate(t *testing.T) { } results, err := m.BatchGenerate(ctx, prompts, inference.WithMaxTokens(8)) - require.NoError(t, err) - require.Len(t, results, 2) + if err != nil { + t.Fatalf("BatchGenerate: %v", err) + } + if len(results) != 2 { + t.Fatalf("len(results) = %d, want 2", len(results)) + } for i, r := range results { - require.NoError(t, r.Err, "batch result %d error", i) - assert.NotEmpty(t, r.Tokens, "batch result %d should have tokens", i) + if r.Err != nil { + t.Fatalf("batch result %d error: %v", i, r.Err) + } + if len(r.Tokens) == 0 { + t.Errorf("batch result %d should have tokens", i) + } var sb strings.Builder for _, tok := range r.Tokens { @@ -267,14 +313,22 @@ func TestROCm_InfoAndMetrics(t *testing.T) { b := &rocmBackend{} m, err := b.LoadModel(testModel, inference.WithContextLen(2048)) - require.NoError(t, err) + if err != nil { + t.Fatalf("LoadModel: %v", err) + } defer m.Close() // Info should be populated from GGUF metadata. 
info := m.Info() - assert.Equal(t, "gemma3", info.Architecture) - assert.Greater(t, info.NumLayers, 0, "expected non-zero layer count") - assert.Greater(t, info.QuantBits, 0, "expected non-zero quant bits") + if info.Architecture != "gemma3" { + t.Errorf("Architecture = %q, want %q", info.Architecture, "gemma3") + } + if info.NumLayers == 0 { + t.Error("expected non-zero layer count") + } + if info.QuantBits == 0 { + t.Error("expected non-zero quant bits") + } t.Logf("Info: arch=%s layers=%d quant=%d-bit group=%d", info.Architecture, info.NumLayers, info.QuantBits, info.QuantGroup) @@ -284,12 +338,20 @@ func TestROCm_InfoAndMetrics(t *testing.T) { for range m.Generate(ctx, "Hello", inference.WithMaxTokens(4)) { } - require.NoError(t, m.Err()) + if err := m.Err(); err != nil { + t.Fatalf("m.Err(): %v", err) + } met := m.Metrics() - assert.Greater(t, met.GeneratedTokens, 0, "expected generated tokens") - assert.Greater(t, met.TotalDuration, time.Duration(0), "expected non-zero duration") - assert.Greater(t, met.DecodeTokensPerSec, float64(0), "expected non-zero decode throughput") + if met.GeneratedTokens == 0 { + t.Error("expected generated tokens") + } + if met.TotalDuration == 0 { + t.Error("expected non-zero duration") + } + if met.DecodeTokensPerSec <= 0 { + t.Error("expected non-zero decode throughput") + } t.Logf("Metrics: gen=%d tok, total=%s, decode=%.1f tok/s, vram=%d MiB", met.GeneratedTokens, met.TotalDuration, met.DecodeTokensPerSec, met.ActiveMemoryBytes/(1024*1024)) @@ -302,13 +364,23 @@ func TestROCm_DiscoverModels(t *testing.T) { } models, err := DiscoverModels(dir) - require.NoError(t, err) - require.NotEmpty(t, models, "expected at least one model in %s", dir) + if err != nil { + t.Fatalf("DiscoverModels: %v", err) + } + if len(models) == 0 { + t.Fatalf("expected at least one model in %s", dir) + } for _, m := range models { t.Logf("Found: %s (%s %s %s, ctx=%d)", filepath.Base(m.Path), m.Architecture, m.Parameters, m.Quantisation, m.ContextLen) - 
assert.NotEmpty(t, m.Architecture) - assert.NotEmpty(t, m.Name) - assert.Greater(t, m.FileSize, int64(0)) + if m.Architecture == "" { + t.Errorf("empty Architecture for %s", m.Path) + } + if m.Name == "" { + t.Errorf("empty Name for %s", m.Path) + } + if m.FileSize <= 0 { + t.Errorf("FileSize = %d for %s, want > 0", m.FileSize, m.Path) + } } } diff --git a/rocm_linux_ax7_test.go b/rocm_linux_ax7_test.go new file mode 100644 index 0000000..7496ae2 --- /dev/null +++ b/rocm_linux_ax7_test.go @@ -0,0 +1,77 @@ +//go:build linux && amd64 + +package rocm + +import ( + "strings" + "testing" +) + +func TestROCm_ROCmAvailable_Good(t *testing.T) { + available := ROCmAvailable() + if !available { + t.Fatal("ROCmAvailable() = false, want true for linux/amd64 build") + } + if ROCmAvailable() != available { + t.Fatal("ROCmAvailable() changed between calls") + } +} + +func TestROCm_ROCmAvailable_Bad(t *testing.T) { + available := ROCmAvailable() + info, err := GetVRAMInfo() + if !available { + t.Fatal("ROCmAvailable() = false, want compiled ROCm path") + } + if err == nil && info.Total == 0 { + t.Fatal("GetVRAMInfo() succeeded with zero total VRAM") + } +} + +func TestROCm_ROCmAvailable_Ugly(t *testing.T) { + first := ROCmAvailable() + second := ROCmAvailable() + if first != second { + t.Fatalf("ROCmAvailable() changed from %v to %v", first, second) + } + if !first { + t.Fatal("ROCmAvailable() = false, want stable true") + } +} + +func TestVRAM_GetVRAMInfo_Good(t *testing.T) { + info, err := GetVRAMInfo() + if err != nil { + if info != (VRAMInfo{}) { + t.Errorf("info = %+v, want zero when error is returned", info) + } + return + } + if info.Total == 0 { + t.Fatal("info.Total = 0, want positive VRAM total") + } + if info.Free > info.Total { + t.Fatalf("info.Free = %d, want <= total %d", info.Free, info.Total) + } +} + +func TestVRAM_GetVRAMInfo_Bad(t *testing.T) { + info, err := GetVRAMInfo() + if err != nil && !strings.Contains(err.Error(), "rocm.GetVRAMInfo") { + t.Fatalf("err = 
%v, want rocm.GetVRAMInfo scope", err)
+	}
+	if err == nil && info.Used > info.Total && info.Free != 0 {
+		t.Fatalf("info = %+v, want free clamped to zero when used exceeds total", info)
+	}
+}
+
+func TestVRAM_GetVRAMInfo_Ugly(t *testing.T) {
+	first, firstErr := GetVRAMInfo()
+	second, secondErr := GetVRAMInfo()
+	if (firstErr == nil) != (secondErr == nil) {
+		t.Fatalf("error stability changed from %v to %v", firstErr, secondErr)
+	}
+	if firstErr == nil && (first.Total == 0 || second.Total == 0) {
+		t.Fatalf("totals = %d, %d; want positive totals", first.Total, second.Total)
+	}
+}
diff --git a/rocm_stub.go b/rocm_stub.go
index 0947fe7..683a3ed 100644
--- a/rocm_stub.go
+++ b/rocm_stub.go
@@ -2,12 +2,19 @@
 
 package rocm
 
-import coreerr "forge.lthn.ai/core/go-log"
+import coreerr "dappco.re/go/log"
 
 // ROCmAvailable reports whether ROCm GPU inference is available.
 // Returns false on non-Linux or non-amd64 platforms.
+//
+//	if !ROCmAvailable() {
+//		fmt.Println("fall back to CPU or another backend")
+//	}
 func ROCmAvailable() bool { return false }
 
 // GetVRAMInfo is not available on non-Linux/non-amd64 platforms.
+//
+//	_, err := GetVRAMInfo()
+//	fmt.Println(err)
func GetVRAMInfo() (VRAMInfo, error) { return VRAMInfo{}, coreerr.E("rocm.GetVRAMInfo", "VRAM monitoring not available on this platform", nil) diff --git a/rocm_stub_ax7_test.go b/rocm_stub_ax7_test.go new file mode 100644 index 0000000..4220869 --- /dev/null +++ b/rocm_stub_ax7_test.go @@ -0,0 +1,74 @@ +//go:build !linux || !amd64 + +package rocm + +import ( + "strings" + "testing" +) + +func TestROCm_ROCmAvailable_Good(t *testing.T) { + available := ROCmAvailable() + if available { + t.Fatal("ROCmAvailable() = true, want false on this platform") + } + if ROCmAvailable() != available { + t.Fatal("ROCmAvailable() changed between calls") + } +} + +func TestROCm_ROCmAvailable_Bad(t *testing.T) { + _, err := GetVRAMInfo() + available := ROCmAvailable() + if available { + t.Fatal("ROCmAvailable() = true, want false when platform stub is active") + } + if err == nil { + t.Fatal("GetVRAMInfo() error = nil, want platform error") + } +} + +func TestROCm_ROCmAvailable_Ugly(t *testing.T) { + first := ROCmAvailable() + second := ROCmAvailable() + if first != second { + t.Fatalf("ROCmAvailable() changed from %v to %v", first, second) + } + if first { + t.Fatal("ROCmAvailable() = true, want false for repeated stub calls") + } +} + +func TestVRAM_GetVRAMInfo_Good(t *testing.T) { + info, err := GetVRAMInfo() + if err == nil { + t.Fatal("GetVRAMInfo() error = nil, want unsupported-platform error") + } + if info != (VRAMInfo{}) { + t.Errorf("info = %+v, want zero VRAMInfo", info) + } +} + +func TestVRAM_GetVRAMInfo_Bad(t *testing.T) { + info, err := GetVRAMInfo() + if err == nil { + t.Fatal("GetVRAMInfo() error = nil, want error") + } + if !strings.Contains(err.Error(), "not available on this platform") { + t.Errorf("err = %v, want platform unavailable message", err) + } + if info.Total != 0 || info.Used != 0 || info.Free != 0 { + t.Errorf("info = %+v, want all fields zero", info) + } +} + +func TestVRAM_GetVRAMInfo_Ugly(t *testing.T) { + first, firstErr := GetVRAMInfo() + second, 
secondErr := GetVRAMInfo()
+	if firstErr == nil || secondErr == nil {
+		t.Fatalf("errors = %v, %v; want both non-nil", firstErr, secondErr)
+	}
+	if first != second {
+		t.Fatalf("GetVRAMInfo() changed from %+v to %+v", first, second)
+	}
+}
diff --git a/server.go b/server.go
index 071e759..eee399b 100644
--- a/server.go
+++ b/server.go
@@ -4,32 +4,67 @@
 
 package rocm
 
 import (
 	"context"
+	// Note: intrinsic - errors.As for stdlib error chain walks; core is downstream.
+	"errors"
+	// Note: intrinsic - fmt.Sprintf for error messages and the local server URL; core.Sprintf could replace it, but server.go predates core helpers.
 	"fmt"
+	// Note: intrinsic - net.Listen/JoinHostPort for localhost port probing; no core equivalent.
 	"net"
+	// Note: intrinsic - os.Getenv/Environ/Stat for env and binary checks; core.Env is downstream of server.
 	"os"
+	// Note: intrinsic - exec.Command manages the llama-server subprocess; no core equivalent.
 	"os/exec"
+	// Note: intrinsic - strconv.Itoa formats ports and CLI flags; core has no equivalent.
 	"strconv"
+	// Note: intrinsic - core string helpers are not yet in scope for this repo.
 	"strings"
+	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
 
-	coreerr "forge.lthn.ai/core/go-log"
-	"forge.lthn.ai/core/go-rocm/internal/llamacpp"
+	coreerr "dappco.re/go/log"
+	"dappco.re/go/rocm/internal/llamacpp"
+)
+
+var (
+	serverStartupTimeout    = 60 * time.Second
+	serverReadyPollInterval = 100 * time.Millisecond
+	serverPortAllocator     = newDeterministicPortAllocator(serverPortRangeStart, serverPortRangeCount)
+	// listenLocalTCP lets tests stub port probing without opening real sockets.
+	listenLocalTCP = net.Listen
+)
+
+const (
+	serverProcessOutputLimit       = 32 << 10
+	serverProcessOutputSummarySize = 1024
+	serverPortRangeStart           = 38080
+	serverPortRangeCount           = 256
 )
 
 // server manages a llama-server subprocess.
type server struct { - cmd *exec.Cmd - port int - client *llamacpp.Client - exited chan struct{} - exitErr error // safe to read only after <-exited + processCommand *exec.Cmd + port int + llamaClient *llamacpp.Client + processExited chan struct{} + processExitError error // safe to read only after <-processExited + processOutput *processOutputCapture +} + +// serverStartConfig keeps llama-server startup settings named instead of positional. +type serverStartConfig struct { + BinaryPath string + ModelPath string + GPULayerCount int + ContextSize int + ParallelSlotCount int } // alive reports whether the llama-server process is still running. func (s *server) alive() bool { select { - case <-s.exited: + case <-s.processExited: return false default: return true @@ -40,10 +75,7 @@ func (s *server) alive() bool { // Checks ROCM_LLAMA_SERVER_PATH first, then PATH. func findLlamaServer() (string, error) { if p := os.Getenv("ROCM_LLAMA_SERVER_PATH"); p != "" { - if _, err := os.Stat(p); err != nil { - return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+p, err) - } - return p, nil + return validateLlamaServerPath(p) } p, err := exec.LookPath("llama-server") if err != nil { @@ -52,25 +84,35 @@ func findLlamaServer() (string, error) { return p, nil } -// freePort asks the kernel for a free TCP port on localhost. 
-func freePort() (int, error) { - ln, err := net.Listen("tcp", "127.0.0.1:0") +func validateLlamaServerPath(path string) (string, error) { + info, err := os.Stat(path) if err != nil { - return 0, coreerr.E("rocm.freePort", "listen for free port", err) + return "", coreerr.E("rocm.findLlamaServer", "llama-server not found at ROCM_LLAMA_SERVER_PATH="+path, err) + } + if info.IsDir() { + return "", coreerr.E("rocm.findLlamaServer", "ROCM_LLAMA_SERVER_PATH must point to a file", nil) + } + if info.Mode().Perm()&0o111 == 0 { + return "", coreerr.E("rocm.findLlamaServer", "llama-server is not executable at ROCM_LLAMA_SERVER_PATH="+path, nil) } - port := ln.Addr().(*net.TCPAddr).Port - ln.Close() - return port, nil + return path, nil +} + +// freePort walks a deterministic localhost port range and returns the first +// currently-bindable port. +func freePort() (int, error) { + return serverPortAllocator.NextAvailablePort() } // serverEnv returns the environment for the llama-server subprocess. -// Filters any existing HIP_VISIBLE_DEVICES and sets it to 0 to mask the iGPU. -// This is critical — the Ryzen 9 iGPU crashes llama-server if not masked. +// Filters any existing HIP_* settings and sets HIP_VISIBLE_DEVICES=0 to mask +// the iGPU. This is critical — the Ryzen 9 iGPU crashes llama-server if not +// masked, and inherited HIP variables can re-expose multi-GPU state. func serverEnv() []string { environ := os.Environ() env := make([]string, 0, len(environ)+1) for _, e := range environ { - if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { + if strings.HasPrefix(e, "HIP_") { continue } env = append(env, e) @@ -80,124 +122,301 @@ func serverEnv() []string { } // startServer spawns llama-server and waits for it to become ready. -// It selects a free port automatically, retrying up to 3 times if the -// process exits during startup (e.g. port conflict). 
-func startServer(binary, modelPath string, gpuLayers, ctxSize, parallelSlots int) (*server, error) { - if gpuLayers < 0 { - gpuLayers = 999 +// It selects a free port automatically, retrying up to 3 times if startup +// fails before the health endpoint becomes ready. +func startServer(startConfig serverStartConfig) (*server, error) { + gpuLayerCount := startConfig.GPULayerCount + if gpuLayerCount < 0 { + gpuLayerCount = 999 } const maxAttempts = 3 - var lastErr error + var lastStartupError error - for attempt := range maxAttempts { + for attempt := 0; attempt < maxAttempts; attempt++ { port, err := freePort() if err != nil { return nil, coreerr.E("rocm.startServer", "find free port", err) } - args := []string{ - "--model", modelPath, - "--host", "127.0.0.1", - "--port", strconv.Itoa(port), - "--n-gpu-layers", strconv.Itoa(gpuLayers), - } - if ctxSize > 0 { - args = append(args, "--ctx-size", strconv.Itoa(ctxSize)) - } - if parallelSlots > 0 { - args = append(args, "--parallel", strconv.Itoa(parallelSlots)) - } + commandArguments := llamaServerArguments(startConfig, port, gpuLayerCount) - cmd := exec.Command(binary, args...) - cmd.Env = serverEnv() + outputCapture := newProcessOutputCapture(serverProcessOutputLimit) + processCommand := exec.Command(startConfig.BinaryPath, commandArguments...) 
+ processCommand.Env = serverEnv() + processCommand.Stdout = outputCapture + processCommand.Stderr = outputCapture - if err := cmd.Start(); err != nil { + if err := processCommand.Start(); err != nil { return nil, coreerr.E("rocm.startServer", "start llama-server", err) } s := &server{ - cmd: cmd, - port: port, - client: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)), - exited: make(chan struct{}), + processCommand: processCommand, + port: port, + llamaClient: llamacpp.NewClient(fmt.Sprintf("http://127.0.0.1:%d", port)), + processExited: make(chan struct{}), + processOutput: outputCapture, } go func() { - s.exitErr = cmd.Wait() - close(s.exited) + s.processExitError = processCommand.Wait() + close(s.processExited) }() - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), serverStartupTimeout) err = s.waitReady(ctx) cancel() if err == nil { return s, nil } - // Only retry if the process actually exited (e.g. port conflict). - // A timeout means the server is stuck, not a port issue. 
- select { - case <-s.exited: - _ = s.stop() - lastErr = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err) - continue - default: - _ = s.stop() - return nil, coreerr.E("rocm.startServer", "llama-server not ready", err) + if stopErr := s.stop(); stopErr != nil { + coreerr.Warn("llama-server cleanup after failed startup returned error", "attempt", attempt+1, "err", stopErr) + } + lastStartupError = coreerr.E("rocm.startServer", fmt.Sprintf("attempt %d", attempt+1), err) + if attempt < maxAttempts-1 { + coreerr.Warn("llama-server startup failed; retrying", "attempt", attempt+1, "max_attempts", maxAttempts, "err", lastStartupError) } } - return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastErr) + return nil, coreerr.E("rocm.startServer", fmt.Sprintf("server failed after %d attempts", maxAttempts), lastStartupError) +} + +func llamaServerArguments(startConfig serverStartConfig, port, gpuLayerCount int) []string { + commandArguments := []string{ + "--model", startConfig.ModelPath, + "--host", "127.0.0.1", + "--port", strconv.Itoa(port), + "--n-gpu-layers", strconv.Itoa(gpuLayerCount), + } + if startConfig.ContextSize > 0 { + commandArguments = append(commandArguments, "--ctx-size", strconv.Itoa(startConfig.ContextSize)) + } + if startConfig.ParallelSlotCount > 0 { + commandArguments = append(commandArguments, "--parallel", strconv.Itoa(startConfig.ParallelSlotCount)) + } + return commandArguments } // waitReady polls the health endpoint until the server is ready. 
func (s *server) waitReady(ctx context.Context) error {
-	ticker := time.NewTicker(100 * time.Millisecond)
+	ticker := time.NewTicker(serverReadyPollInterval)
 	defer ticker.Stop()
 
+	var lastHealthError error
+
 	for {
 		select {
 		case <-ctx.Done():
-			return coreerr.E("server.waitReady", "timeout waiting for llama-server", ctx.Err())
-		case <-s.exited:
-			return coreerr.E("server.waitReady", "llama-server exited before becoming ready", s.exitErr)
+			if lastHealthError != nil {
+				return coreerr.E("server.waitReady", s.messageWithProcessOutput("timeout waiting for llama-server"), lastHealthError)
+			}
+			return coreerr.E("server.waitReady", s.messageWithProcessOutput("timeout waiting for llama-server"), ctx.Err())
+		case <-s.processExited:
+			return s.wrapProcessError("server.waitReady", "llama-server exited before becoming ready", s.processExitError)
 		case <-ticker.C:
-			if err := s.client.Health(ctx); err == nil {
+			err := s.llamaClient.Health(ctx)
+			if err == nil {
 				return nil
 			}
+			lastHealthError = err
 		}
 	}
 }
 
-// stop sends SIGTERM and waits up to 5s, then SIGKILL.
+// stop sends SIGTERM and waits up to 5s, then SIGKILL. Exit caused by those
+// signals is treated as a successful caller-initiated shutdown.
 func (s *server) stop() error {
-	if s.cmd.Process == nil {
+	if s.processCommand.Process == nil {
 		return nil
 	}
 
 	// Already exited?
 	select {
-	case <-s.exited:
-		return s.exitErr
+	case <-s.processExited:
+		if isExpectedStopExitErr(s.processExitError) {
+			return nil
+		}
+		return s.wrapProcessError("server.stop", "llama-server already exited", s.processExitError)
 	default:
 	}
 
 	// Send SIGTERM for graceful shutdown.
-	if err := s.cmd.Process.Signal(syscall.SIGTERM); err != nil {
+	if err := s.processCommand.Process.Signal(syscall.SIGTERM); err != nil {
 		return coreerr.E("server.stop", "sigterm llama-server", err)
 	}
 
 	// Wait up to 5 seconds for clean exit.
select { - case <-s.exited: - return s.exitErr + case <-s.processExited: + if isExpectedStopExitErr(s.processExitError) { + return nil + } + return s.wrapProcessError("server.stop", "llama-server exited after sigterm", s.processExitError) case <-time.After(5 * time.Second): // Force kill. - if err := s.cmd.Process.Kill(); err != nil { + if err := s.processCommand.Process.Kill(); err != nil { return coreerr.E("server.stop", "kill llama-server", err) } - <-s.exited - return s.exitErr + <-s.processExited + if isExpectedStopExitErr(s.processExitError) { + return nil + } + return s.wrapProcessError("server.stop", "llama-server exited after sigkill", s.processExitError) + } +} + +func isExpectedStopExitErr(err error) bool { + if err == nil { + return false + } + + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + return false + } + + status, ok := exitErr.ProcessState.Sys().(syscall.WaitStatus) + if !ok || !status.Signaled() { + return false + } + + switch status.Signal() { + case syscall.SIGTERM, syscall.SIGKILL: + return true + default: + return false + } +} + +func (s *server) messageWithProcessOutput(message string) string { + if s == nil || s.processOutput == nil { + return message + } + output := s.processOutput.Summary() + if output == "" { + return message + } + return message + " (llama-server output: " + output + ")" +} + +func (s *server) wrapProcessError(op, message string, err error) error { + if err == nil { + return nil + } + return coreerr.E(op, s.messageWithProcessOutput(message), err) +} + +type deterministicPortAllocator struct { + basePort int + portCount int + nextPort atomic.Uint64 +} + +func newDeterministicPortAllocator(basePort, portCount int) *deterministicPortAllocator { + return &deterministicPortAllocator{ + basePort: basePort, + portCount: portCount, + } +} + +func (allocator *deterministicPortAllocator) NextAvailablePort() (int, error) { + if allocator == nil || allocator.portCount <= 0 { + return 0, coreerr.E("rocm.freePort", 
"port allocator is not configured", nil) + } + + lastPort := allocator.basePort + allocator.portCount - 1 + if allocator.basePort <= 0 || lastPort > 65535 { + return 0, coreerr.E("rocm.freePort", fmt.Sprintf("invalid port range %d-%d", allocator.basePort, lastPort), nil) + } + + startIndex := allocator.nextPort.Add(1) - 1 + for scanned := 0; scanned < allocator.portCount; scanned++ { + portIndex := int((startIndex + uint64(scanned)) % uint64(allocator.portCount)) + port := allocator.basePort + portIndex + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + + listener, err := listenLocalTCP("tcp", address) + if err != nil { + continue + } + listener.Close() + + allocator.advancePast(startIndex + uint64(scanned) + 1) + return port, nil + } + + return 0, coreerr.E("rocm.freePort", fmt.Sprintf("no free port in deterministic range %d-%d", allocator.basePort, lastPort), nil) +} + +func (allocator *deterministicPortAllocator) advancePast(candidate uint64) { + for { + current := allocator.nextPort.Load() + if current >= candidate { + return + } + if allocator.nextPort.CompareAndSwap(current, candidate) { + return + } + } +} + +type processOutputCapture struct { + maxBytes int + + mu sync.Mutex + buffer []byte + truncated bool +} + +func newProcessOutputCapture(maxBytes int) *processOutputCapture { + return &processOutputCapture{maxBytes: maxBytes} +} + +func (c *processOutputCapture) Write(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + + written := len(p) + if c.maxBytes <= 0 || written == 0 { + return written, nil + } + + c.buffer = append(c.buffer, p...) + if len(c.buffer) > c.maxBytes { + c.buffer = append([]byte(nil), c.buffer[len(c.buffer)-c.maxBytes:]...) 
+ c.truncated = true + } + + return written, nil +} + +func (c *processOutputCapture) Summary() string { + c.mu.Lock() + defer c.mu.Unlock() + + output := strings.TrimSpace(string(c.buffer)) + if output == "" { + return "" + } + + lines := strings.Split(output, "\n") + parts := make([]string, 0, len(lines)) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts = append(parts, line) + } + + output = strings.Join(parts, " | ") + if len(output) > serverProcessOutputSummarySize { + output = output[:serverProcessOutputSummarySize] + "..." + } + if c.truncated { + return "..." + output } + return output } diff --git a/server_ax7_test.go b/server_ax7_test.go new file mode 100644 index 0000000..cf42aa0 --- /dev/null +++ b/server_ax7_test.go @@ -0,0 +1,132 @@ +//go:build linux && amd64 + +package rocm + +import ( + "errors" + "net" + "strings" + "testing" +) + +func TestServer_OutputCapture_Write_Good(t *testing.T) { + capture := newProcessOutputCapture(32) + n, err := capture.Write([]byte("llama ready\n")) + if err != nil { + t.Fatalf("Write() = %v, want nil", err) + } + if n != len("llama ready\n") { + t.Fatalf("Write() n = %d, want %d", n, len("llama ready\n")) + } + if capture.Summary() != "llama ready" { + t.Fatalf("Summary() = %q, want llama ready", capture.Summary()) + } +} + +func TestServer_OutputCapture_Write_Bad(t *testing.T) { + capture := newProcessOutputCapture(0) + n, err := capture.Write([]byte("ignored")) + if err != nil { + t.Fatalf("Write() = %v, want nil", err) + } + if n != len("ignored") { + t.Fatalf("Write() n = %d, want %d", n, len("ignored")) + } + if capture.Summary() != "" { + t.Fatalf("Summary() = %q, want empty", capture.Summary()) + } +} + +func TestServer_OutputCapture_Write_Ugly(t *testing.T) { + capture := newProcessOutputCapture(5) + n, err := capture.Write([]byte("abcdef")) + if err != nil { + t.Fatalf("Write() = %v, want nil", err) + } + if n != len("abcdef") { + t.Fatalf("Write() n = %d, 
want %d", n, len("abcdef")) + } + if capture.Summary() != "...bcdef" { + t.Fatalf("Summary() = %q, want truncated tail", capture.Summary()) + } +} + +func TestServer_OutputCapture_Summary_Good(t *testing.T) { + capture := newProcessOutputCapture(128) + _, err := capture.Write([]byte(" first line \n\n second line \n")) + if err != nil { + t.Fatalf("Write() = %v, want nil", err) + } + if capture.Summary() != "first line | second line" { + t.Fatalf("Summary() = %q, want joined trimmed lines", capture.Summary()) + } +} + +func TestServer_OutputCapture_Summary_Bad(t *testing.T) { + capture := newProcessOutputCapture(128) + summary := capture.Summary() + if summary != "" { + t.Fatalf("Summary() = %q, want empty", summary) + } + if capture.truncated { + t.Fatal("new capture should not be marked truncated") + } +} + +func TestServer_OutputCapture_Summary_Ugly(t *testing.T) { + capture := newProcessOutputCapture(serverProcessOutputSummarySize + 128) + _, err := capture.Write([]byte(strings.Repeat("x", serverProcessOutputSummarySize+32))) + if err != nil { + t.Fatalf("Write() = %v, want nil", err) + } + summary := capture.Summary() + if !strings.HasSuffix(summary, "...") { + t.Fatalf("Summary() = %q, want ellipsis suffix", summary) + } +} + +func TestServer_PortAllocator_NextAvailablePort_Good(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) { + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + + allocator := newDeterministicPortAllocator(41000, 2) + port, err := allocator.NextAvailablePort() + if err != nil { + t.Fatalf("NextAvailablePort() = %v, want nil", err) + } + if port != 41000 { + t.Fatalf("port = %d, want 41000", port) + } +} + +func TestServer_PortAllocator_NextAvailablePort_Bad(t *testing.T) { + allocator := newDeterministicPortAllocator(0, 2) + port, err := allocator.NextAvailablePort() + if err == nil { + t.Fatal("NextAvailablePort() error = nil, want invalid range error") + 
}
+	if port != 0 {
+		t.Fatalf("port = %d, want 0", port)
+	}
+}
+
+func TestServer_PortAllocator_NextAvailablePort_Ugly(t *testing.T) {
+	restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) {
+		if address == "127.0.0.1:42000" {
+			return nil, errors.New("occupied")
+		}
+		return fakeTCPListener{address: address}, nil
+	})
+	defer restoreListen()
+
+	allocator := newDeterministicPortAllocator(42000, 2)
+	port, err := allocator.NextAvailablePort()
+	if err != nil {
+		t.Fatalf("NextAvailablePort() = %v, want nil", err)
+	}
+	if port != 42001 {
+		t.Fatalf("port = %d, want 42001 after skipping occupied 42000", port)
+	}
+}
diff --git a/server_test.go b/server_test.go
index 0ecbc57..224f34c 100644
--- a/server_test.go
+++ b/server_test.go
@@ -4,50 +4,185 @@ package rocm
 
 import (
 	"context"
+	"errors"
+	"net"
 	"os"
+	"os/exec"
+	"path/filepath"
+	"reflect"
+	"strconv"
 	"strings"
 	"testing"
+	"time"
 
-	"forge.lthn.ai/core/go-inference"
-	coreerr "forge.lthn.ai/core/go-log"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
+	"dappco.re/go/inference"
+	coreerr "dappco.re/go/log"
 )
 
 func TestFindLlamaServer_InPATH(t *testing.T) {
 	// llama-server is at /usr/local/bin/llama-server on this machine.
path, err := findLlamaServer() - require.NoError(t, err) - assert.Contains(t, path, "llama-server") + if err != nil { + t.Fatalf("findLlamaServer: %v", err) + } + if !strings.Contains(path, "llama-server") { + t.Errorf("path = %q, want contains %q", path, "llama-server") + } } func TestFindLlamaServer_EnvOverride(t *testing.T) { t.Setenv("ROCM_LLAMA_SERVER_PATH", "/usr/local/bin/llama-server") path, err := findLlamaServer() - require.NoError(t, err) - assert.Equal(t, "/usr/local/bin/llama-server", path) + if err != nil { + t.Fatalf("findLlamaServer: %v", err) + } + if path != "/usr/local/bin/llama-server" { + t.Errorf("path = %q, want %q", path, "/usr/local/bin/llama-server") + } } func TestFindLlamaServer_EnvNotFound(t *testing.T) { t.Setenv("ROCM_LLAMA_SERVER_PATH", "/nonexistent/llama-server") _, err := findLlamaServer() - assert.ErrorContains(t, err, "not found") + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want contains %q", err, "not found") + } +} + +func TestFindLlamaServer_EnvNotExecutable(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "llama-server") + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + t.Setenv("ROCM_LLAMA_SERVER_PATH", path) + + _, err := findLlamaServer() + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "not executable") { + t.Errorf("err = %v, want contains %q", err, "not executable") + } } func TestFreePort(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) { + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + port, err := freePort() - require.NoError(t, err) - assert.Greater(t, port, 0) - assert.Less(t, port, 65536) + if err != nil { + t.Fatalf("freePort: %v", err) + } + if port <= 0 { + t.Errorf("port = %d, want > 0", port) + } + if port >= 65536 { + t.Errorf("port = %d, want 
< 65536", port) + } } func TestFreePort_UniquePerCall(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) { + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + p1, err := freePort() - require.NoError(t, err) + if err != nil { + t.Fatalf("freePort: %v", err) + } p2, err := freePort() - require.NoError(t, err) - _ = p1 - _ = p2 + if err != nil { + t.Fatalf("freePort: %v", err) + } + if p1 == p2 { + t.Errorf("freePort returned identical ports: %d", p1) + } +} + +func TestDeterministicPortAllocator_AdvancesAcrossCalls(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) { + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + + allocator := newDeterministicPortAllocator(41000, 3) + + firstPort, err := allocator.NextAvailablePort() + if err != nil { + t.Fatalf("NextAvailablePort #1: %v", err) + } + + secondPort, err := allocator.NextAvailablePort() + if err != nil { + t.Fatalf("NextAvailablePort #2: %v", err) + } + + if firstPort != 41000 { + t.Errorf("firstPort = %d, want 41000", firstPort) + } + if secondPort != 41001 { + t.Errorf("secondPort = %d, want 41001", secondPort) + } +} + +func TestDeterministicPortAllocator_SkipsOccupiedPort(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) { + if address == "127.0.0.1:42000" { + return nil, errors.New("port already in use") + } + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + + allocator := newDeterministicPortAllocator(42000, 3) + port, err := allocator.NextAvailablePort() + if err != nil { + t.Fatalf("NextAvailablePort: %v", err) + } + if port != 42001 { + t.Errorf("port = %d, want 42001", port) + } +} + +func TestDeterministicPortAllocator_ReturnsErrorWhenRangeIsExhausted(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, 
error) { + return nil, errors.New("port already in use") + }) + defer restoreListen() + + allocator := newDeterministicPortAllocator(43000, 2) + _, err := allocator.NextAvailablePort() + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "no free port in deterministic range") { + t.Errorf("err = %v, want contains %q", err, "no free port in deterministic range") + } +} + +func TestLlamaServerArguments(t *testing.T) { + args := llamaServerArguments(serverStartConfig{ + ModelPath: "/models/gemma3.gguf", + ContextSize: 2048, + ParallelSlotCount: 4, + }, 38123, 999) + + want := []string{ + "--model", "/models/gemma3.gguf", + "--host", "127.0.0.1", + "--port", "38123", + "--n-gpu-layers", "999", + "--ctx-size", "2048", + "--parallel", "4", + } + if !reflect.DeepEqual(args, want) { + t.Errorf("args = %v, want %v", args, want) + } } func TestServerEnv_HIPVisibleDevices(t *testing.T) { @@ -58,19 +193,27 @@ func TestServerEnv_HIPVisibleDevices(t *testing.T) { hipVals = append(hipVals, e) } } - assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals) + want := []string{"HIP_VISIBLE_DEVICES=0"} + if !reflect.DeepEqual(hipVals, want) { + t.Errorf("hipVals = %v, want %v", hipVals, want) + } } func TestServerEnv_FiltersExistingHIP(t *testing.T) { t.Setenv("HIP_VISIBLE_DEVICES", "1") + t.Setenv("HIP_DEVICE_ORDER", "PCI_BUS_ID") + t.Setenv("HIP_TRACE_API", "1") env := serverEnv() var hipVals []string for _, e := range env { - if strings.HasPrefix(e, "HIP_VISIBLE_DEVICES=") { + if strings.HasPrefix(e, "HIP_") { hipVals = append(hipVals, e) } } - assert.Equal(t, []string{"HIP_VISIBLE_DEVICES=0"}, hipVals) + want := []string{"HIP_VISIBLE_DEVICES=0"} + if !reflect.DeepEqual(hipVals, want) { + t.Errorf("hipVals = %v, want %v", hipVals, want) + } } func TestAvailable(t *testing.T) { @@ -78,61 +221,216 @@ func TestAvailable(t *testing.T) { if _, err := os.Stat("/dev/kfd"); err != nil { t.Skip("no ROCm hardware") } - assert.True(t, 
b.Available()) + if !b.Available() { + t.Error("b.Available() = false, want true") + } } - func TestServerAlive_Running(t *testing.T) { - s := &server{exited: make(chan struct{})} - assert.True(t, s.alive()) + s := &server{processExited: make(chan struct{})} + if !s.alive() { + t.Error("s.alive() = false, want true") + } } func TestServerAlive_Exited(t *testing.T) { - exited := make(chan struct{}) - close(exited) - s := &server{exited: exited, exitErr: coreerr.E("test", "process killed", nil)} - assert.False(t, s.alive()) + processExited := make(chan struct{}) + close(processExited) + s := &server{processExited: processExited, processExitError: coreerr.E("test", "process killed", nil)} + if s.alive() { + t.Error("s.alive() = true, want false") + } } func TestGenerate_ServerDead(t *testing.T) { - exited := make(chan struct{}) - close(exited) + processOutput := newProcessOutputCapture(serverProcessOutputLimit) + _, _ = processOutput.Write([]byte("fatal: HIP launch failure\n")) + processExited := make(chan struct{}) + close(processExited) s := &server{ - exited: exited, - exitErr: coreerr.E("test", "process killed", nil), + processExited: processExited, + processExitError: coreerr.E("test", "process killed", nil), + processOutput: processOutput, } - m := &rocmModel{srv: s} + m := &rocmModel{server: s} var count int for range m.Generate(context.Background(), "hello") { count++ } - assert.Equal(t, 0, count) - assert.ErrorContains(t, m.Err(), "server has exited") + if count != 0 { + t.Errorf("count = %d, want 0", count) + } + err := m.Err() + if err == nil { + t.Fatal("m.Err() = nil, want error") + } + if !strings.Contains(err.Error(), "server has exited") { + t.Errorf("err = %v, want contains %q", err, "server has exited") + } + if !strings.Contains(err.Error(), "HIP launch failure") { + t.Errorf("err = %v, want contains %q", err, "HIP launch failure") + } } func TestStartServer_RetriesOnProcessExit(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, 
address string) (net.Listener, error) { + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + // /bin/false starts successfully but exits immediately with code 1. // startServer should retry up to 3 times, then fail. - _, err := startServer("/bin/false", "/nonexistent/model.gguf", 999, 0, 0) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed after 3 attempts") + _, err := startServer(serverStartConfig{ + BinaryPath: "/bin/false", + ModelPath: "/nonexistent/model.gguf", + GPULayerCount: 999, + }) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed after 3 attempts") { + t.Errorf("err = %v, want contains %q", err, "failed after 3 attempts") + } +} + +func TestStartServer_RetriesOnStartupTimeout(t *testing.T) { + restoreListen := stubListenLocalTCP(t, func(network, address string) (net.Listener, error) { + return fakeTCPListener{address: address}, nil + }) + defer restoreListen() + + dir := t.TempDir() + binary := filepath.Join(dir, "fake-llama-server") + if err := os.WriteFile(binary, []byte("#!/bin/sh\nsleep 1\n"), 0755); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + oldTimeout := serverStartupTimeout + oldInterval := serverReadyPollInterval + serverStartupTimeout = 50 * time.Millisecond + serverReadyPollInterval = 10 * time.Millisecond + t.Cleanup(func() { + serverStartupTimeout = oldTimeout + serverReadyPollInterval = oldInterval + }) + + _, err := startServer(serverStartConfig{ + BinaryPath: binary, + ModelPath: "/nonexistent/model.gguf", + GPULayerCount: 999, + }) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed after 3 attempts") { + t.Errorf("err = %v, want contains %q", err, "failed after 3 attempts") + } + if !strings.Contains(err.Error(), "timeout waiting for llama-server") { + t.Errorf("err = %v, want contains %q", err, "timeout waiting for llama-server") + } +} + +func 
TestServerStop_GracefulSignalReturnsNil(t *testing.T) { + processCommand := exec.Command("/bin/sleep", "60") + if err := processCommand.Start(); err != nil { + t.Fatalf("processCommand.Start: %v", err) + } + + s := &server{ + processCommand: processCommand, + processExited: make(chan struct{}), + } + go func() { + s.processExitError = processCommand.Wait() + close(s.processExited) + }() + + if err := s.stop(); err != nil { + t.Fatalf("s.stop(): %v", err) + } + if err := s.stop(); err != nil { + t.Fatalf("s.stop() second call should remain idempotent: %v", err) + } +} + +func TestServerWrapProcessError_IncludesProcessOutput(t *testing.T) { + processOutput := newProcessOutputCapture(serverProcessOutputLimit) + _, _ = processOutput.Write([]byte("HIP runtime exploded\nsecondary detail\n")) + + s := &server{processOutput: processOutput} + + err := s.wrapProcessError("server.waitReady", "llama-server exited before becoming ready", coreerr.E("test", "exit 1", nil)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HIP runtime exploded") { + t.Errorf("err = %v, want contains %q", err, "HIP runtime exploded") + } + if !strings.Contains(err.Error(), "secondary detail") { + t.Errorf("err = %v, want contains %q", err, "secondary detail") + } } func TestChat_ServerDead(t *testing.T) { - exited := make(chan struct{}) - close(exited) + processExited := make(chan struct{}) + close(processExited) s := &server{ - exited: exited, - exitErr: coreerr.E("test", "process killed", nil), + processExited: processExited, + processExitError: coreerr.E("test", "process killed", nil), } - m := &rocmModel{srv: s} + m := &rocmModel{server: s} msgs := []inference.Message{{Role: "user", Content: "hello"}} var count int for range m.Chat(context.Background(), msgs) { count++ } - assert.Equal(t, 0, count) - assert.ErrorContains(t, m.Err(), "server has exited") + if count != 0 { + t.Errorf("count = %d, want 0", count) + } + err := m.Err() + if err == nil { 
+ t.Fatal("m.Err() = nil, want error") + } + if !strings.Contains(err.Error(), "server has exited") { + t.Errorf("err = %v, want contains %q", err, "server has exited") + } +} + +func stubListenLocalTCP(t *testing.T, stub func(network, address string) (net.Listener, error)) func() { + t.Helper() + + original := listenLocalTCP + listenLocalTCP = stub + return func() { + listenLocalTCP = original + } +} + +type fakeTCPListener struct { + address string +} + +func (listener fakeTCPListener) Accept() (net.Conn, error) { + return nil, errors.New("not implemented in test listener") +} + +func (listener fakeTCPListener) Close() error { return nil } + +func (listener fakeTCPListener) Addr() net.Addr { + host, portText, err := net.SplitHostPort(listener.address) + if err != nil { + return &net.TCPAddr{} + } + + port, err := strconv.Atoi(portText) + if err != nil { + return &net.TCPAddr{} + } + + return &net.TCPAddr{ + IP: net.ParseIP(host), + Port: port, + } } diff --git a/tests/cli/rocm/Taskfile.yaml b/tests/cli/rocm/Taskfile.yaml new file mode 100644 index 0000000..ff16022 --- /dev/null +++ b/tests/cli/rocm/Taskfile.yaml @@ -0,0 +1,25 @@ +version: "3" + +tasks: + default: + deps: [test] + + test: + desc: Validate the go-rocm AX-10 CLI artifact driver. + dir: ../../.. + cmds: + - | + export GOWORK=off + export GOCACHE="${GOCACHE:-/tmp/go-rocm-gocache}" + export GOMODCACHE="${GOMODCACHE:-/tmp/go-rocm-gomodcache}" + mkdir -p "$GOCACHE" "$GOMODCACHE" + bin="$(mktemp -t core-rocm.XXXXXX)" + trap 'rm -f "$bin"' EXIT + go build -o "$bin" ./tests/cli/rocm + "$bin" + + driver: + desc: Run the go-rocm AX-10 driver directly. + dir: ../../.. + cmds: + - go run ./tests/cli/rocm diff --git a/tests/cli/rocm/main.go b/tests/cli/rocm/main.go new file mode 100644 index 0000000..46f5e64 --- /dev/null +++ b/tests/cli/rocm/main.go @@ -0,0 +1,243 @@ +// AX-10 CLI driver for go-rocm. 
It exercises the public model discovery and +// platform availability surface without requiring ROCm hardware or llama-server. +// +// task -d tests/cli/rocm test +// go run ./tests/cli/rocm +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + + "dappco.re/go/rocm" +) + +type ggufKV struct { + key string + value any +} + +func main() { + if err := run(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run() error { + if err := verifyDiscoverModels(); err != nil { + return fmt.Errorf("discover models: %w", err) + } + if err := verifyPlatformSurface(); err != nil { + return fmt.Errorf("platform surface: %w", err) + } + return nil +} + +func verifyDiscoverModels() error { + dir, err := os.MkdirTemp("", "go-rocm-ax10-models-") + if err != nil { + return err + } + defer os.RemoveAll(dir) + + gemmaPath, err := writeGGUF(dir, "gemma3-1b-q5km.gguf", []ggufKV{ + {key: "general.architecture", value: "gemma3"}, + {key: "general.name", value: "AX-10 Gemma3"}, + {key: "general.file_type", value: uint32(17)}, + {key: "general.size_label", value: "1B"}, + {key: "gemma3.context_length", value: uint32(32768)}, + {key: "gemma3.block_count", value: uint32(26)}, + }) + if err != nil { + return err + } + + llamaPath, err := writeGGUF(dir, "llama-8b-q4km.gguf", []ggufKV{ + {key: "general.architecture", value: "llama"}, + {key: "general.name", value: "AX-10 Llama"}, + {key: "general.file_type", value: uint32(15)}, + {key: "general.size_label", value: "8B"}, + {key: "llama.context_length", value: uint32(131072)}, + {key: "llama.block_count", value: uint32(32)}, + }) + if err != nil { + return err + } + + if err := os.WriteFile(filepath.Join(dir, "corrupt.gguf"), []byte("not gguf"), 0644); err != nil { + return err + } + if err := os.WriteFile(filepath.Join(dir, "README.txt"), []byte("not a model"), 0644); err != nil { + return err + } + + models, err := rocm.DiscoverModels(dir) + if err != nil { + 
return err + } + if len(models) != 2 { + return fmt.Errorf("len(models) = %d, want 2", len(models)) + } + + if err := expectModel(models[0], rocm.ModelInfo{ + Path: gemmaPath, + Architecture: "gemma3", + Name: "AX-10 Gemma3", + Quantisation: "Q5_K_M", + Parameters: "1B", + ContextLen: 32768, + }); err != nil { + return fmt.Errorf("gemma model: %w", err) + } + + if err := expectModel(models[1], rocm.ModelInfo{ + Path: llamaPath, + Architecture: "llama", + Name: "AX-10 Llama", + Quantisation: "Q4_K_M", + Parameters: "8B", + ContextLen: 131072, + }); err != nil { + return fmt.Errorf("llama model: %w", err) + } + + emptyModels, err := rocm.DiscoverModels(filepath.Join(dir, "missing")) + if err != nil { + return err + } + if len(emptyModels) != 0 { + return fmt.Errorf("missing directory returned %d models, want 0", len(emptyModels)) + } + + _, err = rocm.DiscoverModels(filepath.Join(dir, "bad[")) + if err == nil { + return errors.New("bad glob pattern returned nil error") + } + if !strings.Contains(err.Error(), "glob gguf files") { + return fmt.Errorf("bad glob error = %v", err) + } + + return nil +} + +func expectModel(got rocm.ModelInfo, want rocm.ModelInfo) error { + if got.Path != want.Path { + return fmt.Errorf("Path = %q, want %q", got.Path, want.Path) + } + if got.Architecture != want.Architecture { + return fmt.Errorf("Architecture = %q, want %q", got.Architecture, want.Architecture) + } + if got.Name != want.Name { + return fmt.Errorf("Name = %q, want %q", got.Name, want.Name) + } + if got.Quantisation != want.Quantisation { + return fmt.Errorf("Quantisation = %q, want %q", got.Quantisation, want.Quantisation) + } + if got.Parameters != want.Parameters { + return fmt.Errorf("Parameters = %q, want %q", got.Parameters, want.Parameters) + } + if got.ContextLen != want.ContextLen { + return fmt.Errorf("ContextLen = %d, want %d", got.ContextLen, want.ContextLen) + } + if got.FileSize <= 0 { + return fmt.Errorf("FileSize = %d, want > 0", got.FileSize) + } + return 
nil
+}
+
+func verifyPlatformSurface() error {
+	compiledForROCm := runtime.GOOS == "linux" && runtime.GOARCH == "amd64"
+	if got := rocm.ROCmAvailable(); got != compiledForROCm {
+		return fmt.Errorf("ROCmAvailable() = %v, want %v", got, compiledForROCm)
+	}
+
+	info, err := rocm.GetVRAMInfo()
+	if !compiledForROCm {
+		if err == nil {
+			return errors.New("GetVRAMInfo() on non-ROCm platform returned nil error")
+		}
+		if info != (rocm.VRAMInfo{}) {
+			return fmt.Errorf("GetVRAMInfo() on non-ROCm platform = %+v, want zero value", info)
+		}
+		return nil
+	}
+
+	if err != nil {
+		return nil // sysfs unreadable (e.g. CI box without an AMD GPU); skip the value checks
+	}
+	if info.Total == 0 {
+		return fmt.Errorf("VRAM Total = %d, want > 0", info.Total)
+	}
+	if info.Used > info.Total {
+		return fmt.Errorf("VRAM Used = %d, want <= Total %d", info.Used, info.Total)
+	}
+	if info.Free > info.Total {
+		return fmt.Errorf("VRAM Free = %d, want <= Total %d", info.Free, info.Total)
+	}
+	return nil
+}
+
+func writeGGUF(dir, filename string, kvs []ggufKV) (string, error) {
+	path := filepath.Join(dir, filename)
+
+	file, err := os.Create(path)
+	if err != nil {
+		return "", err
+	}
+	defer file.Close()
+
+	if err := binary.Write(file, binary.LittleEndian, uint32(0x46554747)); err != nil { // magic "GGUF"
+		return "", err
+	}
+	if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil { // GGUF version 3
+		return "", err
+	}
+	if err := binary.Write(file, binary.LittleEndian, uint64(0)); err != nil { // tensor count: none
+		return "", err
+	}
+	if err := binary.Write(file, binary.LittleEndian, uint64(len(kvs))); err != nil { // metadata KV count
+		return "", err
+	}
+
+	for _, kv := range kvs {
+		if err := writeKV(file, kv); err != nil {
+			return "", err
+		}
+	}
+
+	return path, nil
+}
+
+func writeKV(file *os.File, kv ggufKV) error {
+	if err := binary.Write(file, binary.LittleEndian, uint64(len(kv.key))); err != nil {
+		return err
+	}
+	if _, err := file.Write([]byte(kv.key)); err != nil {
+		return err
+	}
+
+	switch value := kv.value.(type) {
+	case string:
+		if err := binary.Write(file, binary.LittleEndian, uint32(8)); err != nil { // GGUF value type 8 = string
+			return err
+		}
+		if err := binary.Write(file, binary.LittleEndian, uint64(len(value))); err != nil {
+			return err
+		}
+		_, err := file.Write([]byte(value))
+		return err
+	case uint32:
+		if err := binary.Write(file, binary.LittleEndian, uint32(4)); err != nil { // GGUF value type 4 = uint32
+			return err
+		}
+		return binary.Write(file, binary.LittleEndian, value)
+	default:
+		return fmt.Errorf("unsupported GGUF value type %T for %q", kv.value, kv.key)
+	}
+}
diff --git a/vram.go b/vram.go
index 9f6d1da..4919011 100644
--- a/vram.go
+++ b/vram.go
@@ -3,17 +3,24 @@ package rocm
 
 import (
+	// Note: os: os.ReadFile for sysfs memory files; core.Fs() does not model sysfs
 	"os"
+	// Note: path/filepath: filepath.Glob/Join for sysfs path walking; no core equivalent for sysfs paths
 	"path/filepath"
+	// Note: strconv: numeric parsing of sysfs values; no core.ParseInt
 	"strconv"
+	// Note: strings: trimming sysfs output whitespace; core.* not in scope for this repo
 	"strings"
 
-	coreerr "forge.lthn.ai/core/go-log"
+	coreerr "dappco.re/go/log"
 )
 
-// GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs.
-// It identifies the dGPU by selecting the card with the largest VRAM total,
-// which avoids hardcoding card numbers (e.g. card0=iGPU, card1=dGPU on Ryzen).
+// GetVRAMInfo reads VRAM usage for the discrete GPU from sysfs. It identifies
+// the dGPU by selecting the card with the largest VRAM total, which avoids
+// hardcoding card numbers (e.g. card0=iGPU, card1=dGPU on Ryzen).
+//
+//	info, err := GetVRAMInfo()
+//	fmt.Printf("%d MiB free\n", info.Free>>20)
 //
 // Note: total and used are read non-atomically from sysfs; transient
 // inconsistencies are possible under heavy allocation churn.
diff --git a/vram_test.go b/vram_test.go index 36c8eca..ca07b3b 100644 --- a/vram_test.go +++ b/vram_test.go @@ -6,42 +6,52 @@ import ( "os" "path/filepath" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestReadSysfsUint64(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test_value") - require.NoError(t, os.WriteFile(path, []byte("17163091968\n"), 0644)) + if err := os.WriteFile(path, []byte("17163091968\n"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } val, err := readSysfsUint64(path) - require.NoError(t, err) - assert.Equal(t, uint64(17163091968), val) + if err != nil { + t.Fatalf("readSysfsUint64: %v", err) + } + if val != uint64(17163091968) { + t.Errorf("readSysfsUint64 = %d, want 17163091968", val) + } } func TestReadSysfsUint64_NotFound(t *testing.T) { - _, err := readSysfsUint64("/nonexistent/path") - assert.Error(t, err) + if _, err := readSysfsUint64("/nonexistent/path"); err == nil { + t.Error("expected error for missing path, got nil") + } } func TestReadSysfsUint64_InvalidContent(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "bad_value") - require.NoError(t, os.WriteFile(path, []byte("not-a-number\n"), 0644)) + if err := os.WriteFile(path, []byte("not-a-number\n"), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } - _, err := readSysfsUint64(path) - assert.Error(t, err) + if _, err := readSysfsUint64(path); err == nil { + t.Error("expected error for non-numeric content, got nil") + } } func TestReadSysfsUint64_EmptyFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "empty_value") - require.NoError(t, os.WriteFile(path, []byte(""), 0644)) + if err := os.WriteFile(path, []byte(""), 0644); err != nil { + t.Fatalf("WriteFile: %v", err) + } - _, err := readSysfsUint64(path) - assert.Error(t, err) + if _, err := readSysfsUint64(path); err == nil { + t.Error("expected error for empty file, got nil") + } } func TestGetVRAMInfo(t *testing.T) { 
@@ -51,7 +61,13 @@ func TestGetVRAMInfo(t *testing.T) { } // On this machine, the dGPU (RX 7800 XT) has ~16GB VRAM. - assert.Greater(t, info.Total, uint64(8*1024*1024*1024), "expected dGPU with >8GB VRAM") - assert.Greater(t, info.Used, uint64(0), "expected some VRAM in use") - assert.Equal(t, info.Total-info.Used, info.Free, "Free should equal Total-Used") + if info.Total <= uint64(8*1024*1024*1024) { + t.Errorf("Total = %d, expected dGPU with >8GB VRAM", info.Total) + } + if info.Used == 0 { + t.Error("Used = 0, expected some VRAM in use") + } + if info.Total-info.Used != info.Free { + t.Errorf("Free = %d, want Total-Used = %d", info.Free, info.Total-info.Used) + } }