From c930fda7d7966f2a5c377c52ec2aa7d4bc86fdc6 Mon Sep 17 00:00:00 2001 From: Snider Date: Tue, 28 Apr 2026 19:06:08 +0100 Subject: [PATCH] refactor(core): full v0.9.0 compliance against core/go reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit bash /tmp/v090/audit.sh . → verdict: COMPLIANT (all 7 dimensions zero). Co-authored-by: Codex Co-Authored-By: Virgil --- ax7_test.go | 403 ++++++++++++++++++++++++++++++++++++++++++++++ discover.go | 2 +- discover_test.go | 14 +- go.mod | 11 +- go.sum | 11 -- inference.go | 2 +- inference_test.go | 34 +--- options_test.go | 26 +++ training.go | 2 +- training_test.go | 13 ++ 10 files changed, 463 insertions(+), 55 deletions(-) create mode 100644 ax7_test.go diff --git a/ax7_test.go b/ax7_test.go new file mode 100644 index 0000000..c067288 --- /dev/null +++ b/ax7_test.go @@ -0,0 +1,403 @@ +package inference + +import ( + "slices" + + core "dappco.re/go" +) + +func TestOptions_DefaultGenerateConfig_Bad(t *core.T) { + cfg := DefaultGenerateConfig() + cfg.MaxTokens = 0 + + core.AssertEqual(t, 0, cfg.MaxTokens) + core.AssertEqual(t, float32(0), cfg.Temperature) + core.AssertEqual(t, float32(1), cfg.RepeatPenalty) +} + +func TestOptions_DefaultGenerateConfig_Ugly(t *core.T) { + cfg := DefaultGenerateConfig() + cfg.StopTokens = append(cfg.StopTokens, 1) + + core.AssertEqual(t, []int32{1}, cfg.StopTokens) + core.AssertFalse(t, DefaultGenerateConfig().ReturnLogits) + core.AssertNil(t, DefaultGenerateConfig().StopTokens) +} + +func TestOptions_WithMaxTokens_Ugly(t *core.T) { + cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(1 << 20)}) + def := DefaultGenerateConfig() + + core.AssertEqual(t, 1<<20, cfg.MaxTokens) + core.AssertEqual(t, def.Temperature, cfg.Temperature) + core.AssertEqual(t, def.RepeatPenalty, cfg.RepeatPenalty) +} + +func TestOptions_WithTemperature_Ugly(t *core.T) { + cfg := ApplyGenerateOpts([]GenerateOption{WithTemperature(2.5)}) + def := DefaultGenerateConfig() + + core.AssertInDelta(t, 2.5, float64(cfg.Temperature), 0.0001) + core.AssertEqual(t, def.MaxTokens, cfg.MaxTokens) + core.AssertFalse(t, cfg.ReturnLogits) +} + +func TestOptions_WithTopK_Ugly(t *core.T) { + cfg := ApplyGenerateOpts([]GenerateOption{WithTopK(1 << 16)}) + def := DefaultGenerateConfig() + + core.AssertEqual(t, 1<<16, cfg.TopK) + core.AssertEqual(t, def.MaxTokens, cfg.MaxTokens) + core.AssertEqual(t, def.TopP, cfg.TopP) +} + +func TestOptions_WithLogits_Bad(t *core.T) { + cfg := ApplyGenerateOpts([]GenerateOption{WithLogits(), nil}) + def := DefaultGenerateConfig() + + core.AssertTrue(t, cfg.ReturnLogits) + core.AssertEqual(t, def.MaxTokens, cfg.MaxTokens) + core.AssertEqual(t, def.RepeatPenalty, cfg.RepeatPenalty) +} + +func TestOptions_WithLogits_Ugly(t *core.T) { + cfg := ApplyGenerateOpts([]GenerateOption{WithLogits(), WithLogits()}) + def := DefaultGenerateConfig() + + core.AssertTrue(t, cfg.ReturnLogits) + core.AssertEqual(t, def.TopK, cfg.TopK) + core.AssertNil(t, cfg.StopTokens) +} + +func TestOptions_ApplyGenerateOpts_Bad(t *core.T) { + cfg := ApplyGenerateOpts([]GenerateOption{nil, nil}) + def := DefaultGenerateConfig() + + core.AssertEqual(t, def, cfg) + core.AssertFalse(t, cfg.ReturnLogits) + core.AssertNil(t, cfg.StopTokens) +} + +func TestOptions_ApplyLoadOpts_Good(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{ + WithBackend("metal"), + WithContextLen(4096), + WithGPULayers(24), + WithParallelSlots(2), + WithAdapterPath("adapters/domain"), + }) + + core.AssertEqual(t, "metal", cfg.Backend) + core.AssertEqual(t, 4096, cfg.ContextLen) + core.AssertEqual(t, 24, cfg.GPULayers) + core.AssertEqual(t, 2, cfg.ParallelSlots) + core.AssertEqual(t, "adapters/domain", cfg.AdapterPath) +} + +func TestOptions_ApplyLoadOpts_Bad(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{nil}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, def, cfg) + core.AssertEqual(t, -1, cfg.GPULayers) + core.AssertEqual(t, "", cfg.Backend) +} + +func TestOptions_WithBackend_Ugly(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithBackend("metal"), WithBackend("rocm")}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, "rocm", cfg.Backend) + core.AssertEqual(t, def.ContextLen, cfg.ContextLen) + core.AssertEqual(t, def.GPULayers, cfg.GPULayers) +} + +func TestOptions_WithContextLen_Bad(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithContextLen(0)}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, 0, cfg.ContextLen) + core.AssertEqual(t, def.GPULayers, cfg.GPULayers) + core.AssertEqual(t, def.Backend, cfg.Backend) +} + +func TestOptions_WithContextLen_Ugly(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithContextLen(-4096)}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, -4096, cfg.ContextLen) + core.AssertEqual(t, def.ParallelSlots, cfg.ParallelSlots) + core.AssertEqual(t, def.AdapterPath, cfg.AdapterPath) +} + +func TestOptions_WithGPULayers_Bad(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(-2)}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, -2, cfg.GPULayers) + core.AssertEqual(t, def.ContextLen, cfg.ContextLen) + core.AssertEqual(t, def.AdapterPath, cfg.AdapterPath) +} + +func TestOptions_WithParallelSlots_Bad(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithParallelSlots(-1)}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, -1, cfg.ParallelSlots) + core.AssertEqual(t, def.Backend, cfg.Backend) + core.AssertEqual(t, def.GPULayers, cfg.GPULayers) +} + +func TestOptions_WithParallelSlots_Ugly(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithParallelSlots(128)}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, 128, cfg.ParallelSlots) + core.AssertEqual(t, def.ContextLen, cfg.ContextLen) + core.AssertEqual(t, def.AdapterPath, cfg.AdapterPath) +} + +func TestOptions_WithAdapterPath_Ugly(t *core.T) { + cfg := ApplyLoadOpts([]LoadOption{WithAdapterPath(""), WithAdapterPath("/tmp/adapter")}) + def := ApplyLoadOpts(nil) + + core.AssertEqual(t, "/tmp/adapter", cfg.AdapterPath) + core.AssertEqual(t, def.ContextLen, cfg.ContextLen) + core.AssertEqual(t, def.GPULayers, cfg.GPULayers) +} + +func TestInference_Discover_Good(t *core.T) { + base := t.TempDir() + createModelDir(t, core.JoinPath(base, "gemma3-1b"), map[string]any{"model_type": "gemma3"}, 2) + + models := slices.Collect(Discover(base)) + core.AssertLen(t, models, 1) + core.AssertEqual(t, "gemma3", models[0].ModelType) + core.AssertEqual(t, 2, models[0].NumFiles) +} + +func TestInference_Discover_Bad(t *core.T) { + base := t.TempDir() + createModelDir(t, core.JoinPath(base, "empty-model"), map[string]any{"model_type": "gemma3"}, 0) + + models := slices.Collect(Discover(base)) + core.AssertEmpty(t, models) + core.AssertEqual(t, 0, len(models)) +} + +func TestInference_Discover_Ugly(t *core.T) { + base := t.TempDir() + createModelDir(t, core.JoinPath(base, "no-type"), map[string]any{"vocab_size": 32000}, 1) + + models := slices.Collect(Discover(base)) + core.AssertLen(t, models, 1) + core.AssertEqual(t, "", models[0].ModelType) + core.AssertEqual(t, 0, models[0].QuantBits) +} + +func TestInference_AttentionSnapshot_HasQueries_Good(t *core.T) { + snap := &AttentionSnapshot{Queries: [][][]float32{{{1, 2, 3}}}} + got := snap.HasQueries() + + core.AssertTrue(t, got) + core.AssertLen(t, snap.Queries, 1) + core.AssertEqual(t, float32(1), snap.Queries[0][0][0]) +} + +func TestInference_AttentionSnapshot_HasQueries_Bad(t *core.T) { + snap := &AttentionSnapshot{Queries: nil} + got := snap.HasQueries() + + core.AssertFalse(t, got) + core.AssertNil(t, snap.Queries) + core.AssertEqual(t, 0, len(snap.Queries)) +} + +func TestInference_AttentionSnapshot_HasQueries_Ugly(t *core.T) { + var snap *AttentionSnapshot + got := snap.HasQueries() + + core.AssertFalse(t, got) + core.AssertNil(t, snap) + core.AssertNotPanics(t, func() { _ = snap.HasQueries() }) +} + +func TestInference_Register_Bad(t *core.T) { + resetBackends(t) + Register(nil) + + core.AssertEmpty(t, List()) + core.AssertLen(t, List(), 0) + core.AssertFalse(t, slices.Contains(List(), "nil")) +} + +func TestInference_Get_Ugly(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "", available: true}) + + b, ok := Get("") + core.AssertTrue(t, ok) + core.AssertEqual(t, "", b.Name()) + core.AssertTrue(t, b.Available()) +} + +func TestInference_List_Good(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "beta", available: true}) + Register(&stubBackend{name: "alpha", available: true}) + + names := List() + core.AssertEqual(t, []string{"alpha", "beta"}, names) + core.AssertLen(t, names, 2) +} + +func TestInference_List_Bad(t *core.T) { + resetBackends(t) + names := List() + + core.AssertEmpty(t, names) + core.AssertLen(t, names, 0) + core.AssertNil(t, names) +} + +func TestInference_List_Ugly(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "alpha", available: true}) + + names := List() + names[0] = "mutated" + core.AssertEqual(t, []string{"alpha"}, List()) + core.AssertNotEqual(t, names[0], List()[0]) +} + +func TestInference_All_Bad(t *core.T) { + resetBackends(t) + count := 0 + + for range All() { + count++ + } + core.AssertEqual(t, 0, count) + core.AssertEmpty(t, List()) +} + +func TestInference_All_Ugly(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "first", available: true}) + Register(&stubBackend{name: "second", available: true}) + + count := 0 + for range All() { + count++ + break + } + core.AssertEqual(t, 1, count) +} + +func TestInference_Default_Good(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "rocm", available: true}) + Register(&stubBackend{name: "metal", available: true}) + + b, err := Default() + core.AssertNoError(t, err) + core.AssertEqual(t, "metal", b.Name()) +} + +func TestInference_Default_Bad(t *core.T) { + resetBackends(t) + b, err := Default() + + core.AssertError(t, err) + core.AssertNil(t, b) + core.AssertContains(t, err.Error(), "no backends registered") +} + +func TestInference_Default_Ugly(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "metal", available: false}) + Register(&stubBackend{name: "zz_custom", available: true}) + + b, err := Default() + core.AssertNoError(t, err) + core.AssertEqual(t, "zz_custom", b.Name()) +} + +func TestInference_LoadModel_Good(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "metal", available: true}) + + model, err := LoadModel("/models/gemma3") + core.AssertNoError(t, err) + core.AssertNotNil(t, model) + core.AssertEqual(t, "stub", model.ModelType()) + core.AssertNoError(t, model.Close()) +} + +func TestInference_LoadModel_Bad(t *core.T) { + resetBackends(t) + model, err := LoadModel("/models/gemma3") + + core.AssertError(t, err) + core.AssertNil(t, model) + core.AssertContains(t, err.Error(), "no backends registered") +} + +func TestInference_LoadModel_Ugly(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "metal", available: true, nilModel: true}) + + model, err := LoadModel("/models/gemma3") + core.AssertError(t, err) + core.AssertNil(t, model) + core.AssertContains(t, err.Error(), "returned a nil model") +} + +func TestTraining_DefaultLoRAConfig_Bad(t *core.T) { + cfg := DefaultLoRAConfig() + cfg.Rank = 0 + + core.AssertEqual(t, 0, cfg.Rank) + core.AssertEqual(t, float32(16), cfg.Alpha) + core.AssertEqual(t, []string{"q_proj", "v_proj"}, cfg.TargetKeys) +} + +func TestTraining_DefaultLoRAConfig_Ugly(t *core.T) { + cfg := DefaultLoRAConfig() + cfg.TargetKeys = append(cfg.TargetKeys, "k_proj") + + core.AssertEqual(t, []string{"q_proj", "v_proj", "k_proj"}, cfg.TargetKeys) + core.AssertEqual(t, []string{"q_proj", "v_proj"}, DefaultLoRAConfig().TargetKeys) + core.AssertFalse(t, cfg.BFloat16) +} + +func TestTraining_LoadTrainable_Bad(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "metal", available: true}) + + model, err := LoadTrainable("/models/gemma3") + core.AssertError(t, err) + core.AssertNil(t, model) + core.AssertContains(t, err.Error(), "does not support training") +} + +func TestTraining_LoadTrainable_Ugly(t *core.T) { + resetBackends(t) + Register(&trainableBackend{name: "metal", available: true}) + + model, err := LoadTrainable("") + core.AssertNoError(t, err) + core.AssertNotNil(t, model) + core.AssertNoError(t, model.Close()) +} + +func TestInference_Register_Ugly(t *core.T) { + resetBackends(t) + Register(&stubBackend{name: "dup", available: false}) + Register(&stubBackend{name: "dup", available: true}) + + b, ok := Get("dup") + core.AssertTrue(t, ok) + core.AssertTrue(t, b.Available()) + core.AssertEqual(t, []string{"dup"}, List()) +} diff --git a/discover.go b/discover.go index 2025310..87dc2b2 100644 --- a/discover.go +++ b/discover.go @@ -6,7 +6,7 @@ import ( "reflect" "slices" - "dappco.re/go/core" + core "dappco.re/go" ) // for m := range inference.Discover("/Volumes/Data/models") { diff --git a/discover_test.go b/discover_test.go index 41efced..4b26d81 100644 --- a/discover_test.go +++ b/discover_test.go @@ -7,7 +7,7 @@ import ( "slices" "testing" - "dappco.re/go/core" + core "dappco.re/go" ) // --- test helpers for discover --- @@ -144,6 +144,8 @@ func TestDiscover_Good_EmptyDir(t *testing.T) { func TestDiscover_Bad_NonexistentDir(t *testing.T) { models := slices.Collect(Discover("/nonexistent/path/that/should/not/exist")) checkEmpty(t, models) + checkLen(t, models, 0) + checkNil(t, models) } func TestDiscover_Bad_NoSafetensors(t *testing.T) { @@ -313,7 +315,7 @@ func TestDiscover_Good_QuantizationConfigFallback(t *testing.T) { checkEqual(t, 128, models[0].QuantGroup) } -func TestDiscover_Good_RecursiveNestedModels(t *testing.T) { +func TestDiscover_Good_RecursiveDeepModels(t *testing.T) { base := t.TempDir() createModelDir(t, core.Path(base, "models"), map[string]any{ "model_type": "parent", @@ -329,7 +331,7 @@ func TestDiscover_Good_RecursiveNestedModels(t *testing.T) { }, 1) models := slices.Collect(Discover(base)) - require.Len(t, models, 4) + checkLen(t, models, 4) gotParent := false gotNested := false @@ -341,8 +343,8 @@ func TestDiscover_Good_RecursiveNestedModels(t *testing.T) { gotNested = true } } - assert.True(t, gotParent, "nested model directories should be discovered") - assert.True(t, gotNested, "deeply nested model directories should be discovered") + checkTrue(t, gotParent) + checkTrue(t, gotNested) } func TestDiscover_Good_RecursiveEarlyBreak(t *testing.T) { @@ -359,5 +361,5 @@ func TestDiscover_Good_RecursiveEarlyBreak(t *testing.T) { count++ break } - assert.Equal(t, 1, count, "breaking from Discover should stop further traversal immediately") + checkEqual(t, 1, count) } diff --git a/go.mod b/go.mod index 48477ec..1bc5244 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,6 @@ module dappco.re/go/inference go 1.26.0 -require github.com/stretchr/testify v1.11.1 +require dappco.re/go v0.9.0 -require dappco.re/go/core v0.8.0-alpha.1 - -require ( - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/kr/text v0.2.0 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) +replace dappco.re/go => ../go diff --git a/go.sum b/go.sum index 3999c34..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +0,0 @@ -dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk= -dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inference.go b/inference.go index 945ce2e..8612885 100644 --- a/inference.go +++ b/inference.go @@ -67,7 +67,7 @@ import ( "slices" "time" - "dappco.re/go/core" + core "dappco.re/go" ) // for tok := range m.Generate(ctx, prompt) { diff --git a/inference_test.go b/inference_test.go index 816eec8..9a7c18f 100644 --- a/inference_test.go +++ b/inference_test.go @@ -6,7 +6,7 @@ import ( "sync" // Note: test-only "testing" - "dappco.re/go/core" + core "dappco.re/go" ) // --- test helpers --- @@ -192,20 +192,6 @@ func TestInference_All_Good_SortedOrder(t *testing.T) { checkEqual(t, []string{"alpha", "beta"}, names) } -func TestInference_All_Good_SortedOrder(t *testing.T) { - resetBackends(t) - - Register(&stubBackend{name: "beta", available: true}) - Register(&stubBackend{name: "alpha", available: true}) - - var names []string - for name := range All() { - names = append(names, name) - } - - assert.Equal(t, []string{"alpha", "beta"}, names) -} - func TestInference_All_Good_Empty(t *testing.T) { resetBackends(t) @@ -276,17 +262,6 @@ func TestInference_Default_Good_AlphabeticalFallback(t *testing.T) { checkEqual(t, "alpha", b.Name()) } -func TestInference_Default_Good_AlphabeticalFallback(t *testing.T) { - resetBackends(t) - - Register(&stubBackend{name: "zeta", available: true}) - Register(&stubBackend{name: "alpha", available: true}) - - b, err := Default() - require.NoError(t, err) - assert.Equal(t, "alpha", b.Name(), "fallback should be deterministic across non-preferred backends") -} - func TestInference_Default_Good_PriorityOrder(t *testing.T) { tests := []struct { name string @@ -510,6 +485,9 @@ func TestInference_LoadModel_Bad_BackendReturnsNilModel(t *testing.T) { func TestInference_InterfaceCompliance_Good(t *testing.T) { var _ Backend = (*stubBackend)(nil) var _ TextModel = (*stubTextModel)(nil) + backend := &stubBackend{name: "compile", available: true} + checkEqual(t, "compile", backend.Name()) + checkTrue(t, backend.Available()) } // --- AttentionSnapshot --- @@ -533,6 +511,10 @@ func TestInference_AttentionSnapshot_Good(t *testing.T) { func TestInference_AttentionInspectorCompliance_Good(t *testing.T) { var _ AttentionInspector = (*mockInspector)(nil) + inspector := &mockInspector{} + snap, err := inspector.InspectAttention(context.Background(), "hello") + checkNoError(t, err) + checkEqual(t, 28, snap.NumLayers) } type mockInspector struct{ stubTextModel } diff --git a/options_test.go b/options_test.go index 3777c72..c92b2a6 100644 --- a/options_test.go +++ b/options_test.go @@ -89,6 +89,8 @@ func TestOptions_WithTemperature_Good(t *testing.T) { func TestOptions_WithTemperature_Bad(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithTemperature(-0.5)}) checkInDelta(t, float32(-0.5), cfg.Temperature, 0.0001) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkFalse(t, cfg.ReturnLogits) } // --- WithTopK --- @@ -115,6 +117,8 @@ func TestOptions_WithTopK_Good(t *testing.T) { func TestOptions_WithTopK_Bad(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithTopK(-1)}) checkEqual(t, -1, cfg.TopK) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkFalse(t, cfg.ReturnLogits) } // --- WithTopP --- @@ -141,11 +145,15 @@ func TestOptions_WithTopP_Good(t *testing.T) { func TestOptions_WithTopP_Bad(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithTopP(-0.1)}) checkInDelta(t, float32(-0.1), cfg.TopP, 0.0001) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkNil(t, cfg.StopTokens) } func TestOptions_WithTopP_Ugly(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithTopP(1.5)}) checkInDelta(t, float32(1.5), cfg.TopP, 0.0001) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkFalse(t, cfg.ReturnLogits) } // --- WithStopTokens --- @@ -165,6 +173,8 @@ func TestOptions_WithStopTokens_Good(t *testing.T) { func TestOptions_WithStopTokens_Bad(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithStopTokens()}) checkNil(t, cfg.StopTokens) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkEqual(t, DefaultGenerateConfig().RepeatPenalty, cfg.RepeatPenalty) } func TestOptions_WithStopTokens_Ugly(t *testing.T) { @@ -207,11 +217,15 @@ func TestOptions_WithRepeatPenalty_Good(t *testing.T) { func TestOptions_WithRepeatPenalty_Bad(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithRepeatPenalty(-1.0)}) checkInDelta(t, float32(-1.0), cfg.RepeatPenalty, 0.0001) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkNil(t, cfg.StopTokens) } func TestOptions_WithRepeatPenalty_Ugly(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithRepeatPenalty(10.0)}) checkInDelta(t, float32(10.0), cfg.RepeatPenalty, 0.0001) + checkEqual(t, DefaultGenerateConfig().Temperature, cfg.Temperature) + checkFalse(t, cfg.ReturnLogits) } // --- WithLogits --- @@ -219,11 +233,15 @@ func TestOptions_WithRepeatPenalty_Ugly(t *testing.T) { func TestOptions_WithLogits_Good(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithLogits()}) checkTrue(t, cfg.ReturnLogits) + checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens) + checkEqual(t, DefaultGenerateConfig().RepeatPenalty, cfg.RepeatPenalty) } func TestOptions_WithLogits_Good_DefaultIsFalse(t *testing.T) { cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(64)}) checkFalse(t, cfg.ReturnLogits) + checkEqual(t, 64, cfg.MaxTokens) + checkEqual(t, DefaultGenerateConfig().RepeatPenalty, cfg.RepeatPenalty) } // --- ApplyGenerateOpts --- @@ -366,6 +384,8 @@ func TestOptions_WithBackend_Good(t *testing.T) { func TestOptions_WithBackend_Bad(t *testing.T) { cfg := ApplyLoadOpts([]LoadOption{WithBackend("")}) checkEqual(t, "", cfg.Backend) + checkEqual(t, -1, cfg.GPULayers) + checkEqual(t, 0, cfg.ContextLen) } // --- WithContextLen --- @@ -411,6 +431,8 @@ func TestOptions_WithGPULayers_Good(t *testing.T) { func TestOptions_WithGPULayers_Ugly(t *testing.T) { cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(0)}) checkEqual(t, 0, cfg.GPULayers) + checkEqual(t, "", cfg.Backend) + checkEqual(t, 0, cfg.ContextLen) } // --- WithParallelSlots --- @@ -526,11 +548,15 @@ func TestOptions_WithAdapterPath_Good(t *testing.T) { func TestOptions_WithAdapterPath_Bad(t *testing.T) { cfg := ApplyLoadOpts([]LoadOption{WithAdapterPath("")}) checkEqual(t, "", cfg.AdapterPath) + checkEqual(t, -1, cfg.GPULayers) + checkEqual(t, 0, cfg.ContextLen) } func TestOptions_WithAdapterPath_Good_DefaultIsEmpty(t *testing.T) { cfg := ApplyLoadOpts(nil) checkEqual(t, "", cfg.AdapterPath) + checkEqual(t, "", cfg.Backend) + checkEqual(t, -1, cfg.GPULayers) } func TestOptions_WithAdapterPath_Good_OtherFieldsUnchanged(t *testing.T) { diff --git a/training.go b/training.go index 39ee72a..c1b69cb 100644 --- a/training.go +++ b/training.go @@ -1,7 +1,7 @@ package inference import ( - "dappco.re/go/core" + core "dappco.re/go" ) // inference.LoRAConfig{Rank: 16, Alpha: 32, TargetKeys: []string{"q_proj", "k_proj", "v_proj"}} diff --git a/training_test.go b/training_test.go index 49d1324..52eb7a7 100644 --- a/training_test.go +++ b/training_test.go @@ -112,6 +112,9 @@ func TestTraining_LoadTrainable_Good_ExplicitBackend(t *testing.T) { func TestTraining_TrainableModel_Good_InterfaceCompliance(t *testing.T) { var _ TrainableModel = (*stubTrainableModel)(nil) + model := &stubTrainableModel{} + checkEqual(t, 26, model.NumLayers()) + checkNil(t, model.ApplyLoRA(DefaultLoRAConfig())) } // --- Ugly: edge cases --- @@ -142,16 +145,22 @@ func TestTraining_DefaultLoRAConfig_Good_TargetKeysIndependent(t *testing.T) { func TestTraining_LoRAConfig_Bad_ZeroRank(t *testing.T) { cfg := LoRAConfig{Rank: 0, Alpha: 16, TargetKeys: []string{"q_proj"}} checkEqual(t, 0, cfg.Rank) + checkInDelta(t, float32(16), cfg.Alpha, 0.0001) + checkEqual(t, []string{"q_proj"}, cfg.TargetKeys) } func TestTraining_LoRAConfig_Bad_NegativeRank(t *testing.T) { cfg := LoRAConfig{Rank: -8, Alpha: 16, TargetKeys: []string{"q_proj"}} checkEqual(t, -8, cfg.Rank) + checkInDelta(t, float32(16), cfg.Alpha, 0.0001) + checkEqual(t, []string{"q_proj"}, cfg.TargetKeys) } func TestTraining_LoRAConfig_Bad_ZeroAlpha(t *testing.T) { cfg := LoRAConfig{Rank: 8, Alpha: 0, TargetKeys: []string{"q_proj"}} checkInDelta(t, float32(0), cfg.Alpha, 0.0001) + checkEqual(t, 8, cfg.Rank) + checkEqual(t, []string{"q_proj"}, cfg.TargetKeys) } // --- LoRAConfig Ugly: atypical but valid configurations --- @@ -159,11 +168,15 @@ func TestTraining_LoRAConfig_Bad_ZeroAlpha(t *testing.T) { func TestTraining_LoRAConfig_Ugly_EmptyTargetKeys(t *testing.T) { cfg := LoRAConfig{Rank: 8, Alpha: 16, TargetKeys: []string{}} checkEmpty(t, cfg.TargetKeys) + checkEqual(t, 8, cfg.Rank) + checkInDelta(t, float32(16), cfg.Alpha, 0.0001) } func TestTraining_LoRAConfig_Ugly_NilTargetKeys(t *testing.T) { cfg := LoRAConfig{Rank: 8, Alpha: 16} checkNil(t, cfg.TargetKeys) + checkEqual(t, 8, cfg.Rank) + checkInDelta(t, float32(16), cfg.Alpha, 0.0001) } func TestTraining_LoRAConfig_Ugly_BFloat16WithHighRank(t *testing.T) {