Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
403 changes: 403 additions & 0 deletions ax7_test.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"reflect"
"slices"

"dappco.re/go/core"
core "dappco.re/go"
)

// for m := range inference.Discover("/Volumes/Data/models") {
Expand Down
14 changes: 8 additions & 6 deletions discover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"slices"
"testing"

"dappco.re/go/core"
core "dappco.re/go"
)

// --- test helpers for discover ---
Expand Down Expand Up @@ -144,6 +144,8 @@ func TestDiscover_Good_EmptyDir(t *testing.T) {
func TestDiscover_Bad_NonexistentDir(t *testing.T) {
models := slices.Collect(Discover("/nonexistent/path/that/should/not/exist"))
checkEmpty(t, models)
checkLen(t, models, 0)
checkNil(t, models)
}

func TestDiscover_Bad_NoSafetensors(t *testing.T) {
Expand Down Expand Up @@ -313,7 +315,7 @@ func TestDiscover_Good_QuantizationConfigFallback(t *testing.T) {
checkEqual(t, 128, models[0].QuantGroup)
}

func TestDiscover_Good_RecursiveNestedModels(t *testing.T) {
func TestDiscover_Good_RecursiveDeepModels(t *testing.T) {
base := t.TempDir()
createModelDir(t, core.Path(base, "models"), map[string]any{
"model_type": "parent",
Expand All @@ -329,7 +331,7 @@ func TestDiscover_Good_RecursiveNestedModels(t *testing.T) {
}, 1)

models := slices.Collect(Discover(base))
require.Len(t, models, 4)
checkLen(t, models, 4)

gotParent := false
gotNested := false
Expand All @@ -341,8 +343,8 @@ func TestDiscover_Good_RecursiveNestedModels(t *testing.T) {
gotNested = true
}
}
assert.True(t, gotParent, "nested model directories should be discovered")
assert.True(t, gotNested, "deeply nested model directories should be discovered")
checkTrue(t, gotParent)
checkTrue(t, gotNested)
}

func TestDiscover_Good_RecursiveEarlyBreak(t *testing.T) {
Expand All @@ -359,5 +361,5 @@ func TestDiscover_Good_RecursiveEarlyBreak(t *testing.T) {
count++
break
}
assert.Equal(t, 1, count, "breaking from Discover should stop further traversal immediately")
checkEqual(t, 1, count)
}
11 changes: 2 additions & 9 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@ module dappco.re/go/inference

go 1.26.0

require github.com/stretchr/testify v1.11.1
require dappco.re/go v0.9.0

require dappco.re/go/core v0.8.0-alpha.1

require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace dappco.re/go => ../go
11 changes: 0 additions & 11 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,11 +0,0 @@
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
2 changes: 1 addition & 1 deletion inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import (
"slices"
"time"

"dappco.re/go/core"
core "dappco.re/go"
)

// for tok := range m.Generate(ctx, prompt) {
Expand Down
34 changes: 8 additions & 26 deletions inference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"sync" // Note: test-only
"testing"

"dappco.re/go/core"
core "dappco.re/go"
)

// --- test helpers ---
Expand Down Expand Up @@ -192,20 +192,6 @@ func TestInference_All_Good_SortedOrder(t *testing.T) {
checkEqual(t, []string{"alpha", "beta"}, names)
}

func TestInference_All_Good_SortedOrder(t *testing.T) {
resetBackends(t)

Register(&stubBackend{name: "beta", available: true})
Register(&stubBackend{name: "alpha", available: true})

var names []string
for name := range All() {
names = append(names, name)
}

assert.Equal(t, []string{"alpha", "beta"}, names)
}

func TestInference_All_Good_Empty(t *testing.T) {
resetBackends(t)

Expand Down Expand Up @@ -276,17 +262,6 @@ func TestInference_Default_Good_AlphabeticalFallback(t *testing.T) {
checkEqual(t, "alpha", b.Name())
}

func TestInference_Default_Good_AlphabeticalFallback(t *testing.T) {
resetBackends(t)

Register(&stubBackend{name: "zeta", available: true})
Register(&stubBackend{name: "alpha", available: true})

b, err := Default()
require.NoError(t, err)
assert.Equal(t, "alpha", b.Name(), "fallback should be deterministic across non-preferred backends")
}

func TestInference_Default_Good_PriorityOrder(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -510,6 +485,9 @@ func TestInference_LoadModel_Bad_BackendReturnsNilModel(t *testing.T) {
func TestInference_InterfaceCompliance_Good(t *testing.T) {
var _ Backend = (*stubBackend)(nil)
var _ TextModel = (*stubTextModel)(nil)
backend := &stubBackend{name: "compile", available: true}
checkEqual(t, "compile", backend.Name())
checkTrue(t, backend.Available())
}

// --- AttentionSnapshot ---
Expand All @@ -533,6 +511,10 @@ func TestInference_AttentionSnapshot_Good(t *testing.T) {

func TestInference_AttentionInspectorCompliance_Good(t *testing.T) {
var _ AttentionInspector = (*mockInspector)(nil)
inspector := &mockInspector{}
snap, err := inspector.InspectAttention(context.Background(), "hello")
checkNoError(t, err)
checkEqual(t, 28, snap.NumLayers)
}

type mockInspector struct{ stubTextModel }
Expand Down
26 changes: 26 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ func TestOptions_WithTemperature_Good(t *testing.T) {
func TestOptions_WithTemperature_Bad(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTemperature(-0.5)})
checkInDelta(t, float32(-0.5), cfg.Temperature, 0.0001)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkFalse(t, cfg.ReturnLogits)
}

// --- WithTopK ---
Expand All @@ -115,6 +117,8 @@ func TestOptions_WithTopK_Good(t *testing.T) {
func TestOptions_WithTopK_Bad(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTopK(-1)})
checkEqual(t, -1, cfg.TopK)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkFalse(t, cfg.ReturnLogits)
}

// --- WithTopP ---
Expand All @@ -141,11 +145,15 @@ func TestOptions_WithTopP_Good(t *testing.T) {
func TestOptions_WithTopP_Bad(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTopP(-0.1)})
checkInDelta(t, float32(-0.1), cfg.TopP, 0.0001)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkNil(t, cfg.StopTokens)
}

func TestOptions_WithTopP_Ugly(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTopP(1.5)})
checkInDelta(t, float32(1.5), cfg.TopP, 0.0001)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkFalse(t, cfg.ReturnLogits)
}

// --- WithStopTokens ---
Expand All @@ -165,6 +173,8 @@ func TestOptions_WithStopTokens_Good(t *testing.T) {
func TestOptions_WithStopTokens_Bad(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithStopTokens()})
checkNil(t, cfg.StopTokens)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkEqual(t, DefaultGenerateConfig().RepeatPenalty, cfg.RepeatPenalty)
}

func TestOptions_WithStopTokens_Ugly(t *testing.T) {
Expand Down Expand Up @@ -207,23 +217,31 @@ func TestOptions_WithRepeatPenalty_Good(t *testing.T) {
func TestOptions_WithRepeatPenalty_Bad(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithRepeatPenalty(-1.0)})
checkInDelta(t, float32(-1.0), cfg.RepeatPenalty, 0.0001)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkNil(t, cfg.StopTokens)
}

func TestOptions_WithRepeatPenalty_Ugly(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithRepeatPenalty(10.0)})
checkInDelta(t, float32(10.0), cfg.RepeatPenalty, 0.0001)
checkEqual(t, DefaultGenerateConfig().Temperature, cfg.Temperature)
checkFalse(t, cfg.ReturnLogits)
}

// --- WithLogits ---

func TestOptions_WithLogits_Good(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithLogits()})
checkTrue(t, cfg.ReturnLogits)
checkEqual(t, DefaultGenerateConfig().MaxTokens, cfg.MaxTokens)
checkEqual(t, DefaultGenerateConfig().RepeatPenalty, cfg.RepeatPenalty)
}

func TestOptions_WithLogits_Good_DefaultIsFalse(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(64)})
checkFalse(t, cfg.ReturnLogits)
checkEqual(t, 64, cfg.MaxTokens)
checkEqual(t, DefaultGenerateConfig().RepeatPenalty, cfg.RepeatPenalty)
}

// --- ApplyGenerateOpts ---
Expand Down Expand Up @@ -366,6 +384,8 @@ func TestOptions_WithBackend_Good(t *testing.T) {
func TestOptions_WithBackend_Bad(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithBackend("")})
checkEqual(t, "", cfg.Backend)
checkEqual(t, -1, cfg.GPULayers)
checkEqual(t, 0, cfg.ContextLen)
}

// --- WithContextLen ---
Expand Down Expand Up @@ -411,6 +431,8 @@ func TestOptions_WithGPULayers_Good(t *testing.T) {
func TestOptions_WithGPULayers_Ugly(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(0)})
checkEqual(t, 0, cfg.GPULayers)
checkEqual(t, "", cfg.Backend)
checkEqual(t, 0, cfg.ContextLen)
}

// --- WithParallelSlots ---
Expand Down Expand Up @@ -526,11 +548,15 @@ func TestOptions_WithAdapterPath_Good(t *testing.T) {
func TestOptions_WithAdapterPath_Bad(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithAdapterPath("")})
checkEqual(t, "", cfg.AdapterPath)
checkEqual(t, -1, cfg.GPULayers)
checkEqual(t, 0, cfg.ContextLen)
}

func TestOptions_WithAdapterPath_Good_DefaultIsEmpty(t *testing.T) {
cfg := ApplyLoadOpts(nil)
checkEqual(t, "", cfg.AdapterPath)
checkEqual(t, "", cfg.Backend)
checkEqual(t, -1, cfg.GPULayers)
}

func TestOptions_WithAdapterPath_Good_OtherFieldsUnchanged(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion training.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package inference

import (
"dappco.re/go/core"
core "dappco.re/go"
)

// inference.LoRAConfig{Rank: 16, Alpha: 32, TargetKeys: []string{"q_proj", "k_proj", "v_proj"}}
Expand Down
13 changes: 13 additions & 0 deletions training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ func TestTraining_LoadTrainable_Good_ExplicitBackend(t *testing.T) {

func TestTraining_TrainableModel_Good_InterfaceCompliance(t *testing.T) {
var _ TrainableModel = (*stubTrainableModel)(nil)
model := &stubTrainableModel{}
checkEqual(t, 26, model.NumLayers())
checkNil(t, model.ApplyLoRA(DefaultLoRAConfig()))
}

// --- Ugly: edge cases ---
Expand Down Expand Up @@ -142,28 +145,38 @@ func TestTraining_DefaultLoRAConfig_Good_TargetKeysIndependent(t *testing.T) {
func TestTraining_LoRAConfig_Bad_ZeroRank(t *testing.T) {
cfg := LoRAConfig{Rank: 0, Alpha: 16, TargetKeys: []string{"q_proj"}}
checkEqual(t, 0, cfg.Rank)
checkInDelta(t, float32(16), cfg.Alpha, 0.0001)
checkEqual(t, []string{"q_proj"}, cfg.TargetKeys)
}

func TestTraining_LoRAConfig_Bad_NegativeRank(t *testing.T) {
cfg := LoRAConfig{Rank: -8, Alpha: 16, TargetKeys: []string{"q_proj"}}
checkEqual(t, -8, cfg.Rank)
checkInDelta(t, float32(16), cfg.Alpha, 0.0001)
checkEqual(t, []string{"q_proj"}, cfg.TargetKeys)
}

func TestTraining_LoRAConfig_Bad_ZeroAlpha(t *testing.T) {
cfg := LoRAConfig{Rank: 8, Alpha: 0, TargetKeys: []string{"q_proj"}}
checkInDelta(t, float32(0), cfg.Alpha, 0.0001)
checkEqual(t, 8, cfg.Rank)
checkEqual(t, []string{"q_proj"}, cfg.TargetKeys)
}

// --- LoRAConfig Ugly: atypical but valid configurations ---

func TestTraining_LoRAConfig_Ugly_EmptyTargetKeys(t *testing.T) {
cfg := LoRAConfig{Rank: 8, Alpha: 16, TargetKeys: []string{}}
checkEmpty(t, cfg.TargetKeys)
checkEqual(t, 8, cfg.Rank)
checkInDelta(t, float32(16), cfg.Alpha, 0.0001)
}

func TestTraining_LoRAConfig_Ugly_NilTargetKeys(t *testing.T) {
cfg := LoRAConfig{Rank: 8, Alpha: 16}
checkNil(t, cfg.TargetKeys)
checkEqual(t, 8, cfg.Rank)
checkInDelta(t, float32(16), cfg.Alpha, 0.0001)
}

func TestTraining_LoRAConfig_Ugly_BFloat16WithHighRank(t *testing.T) {
Expand Down